Skip to content

Commit

Permalink
feat: add fine-mapping of one study_locus_row
Browse files Browse the repository at this point in the history
  • Loading branch information
addramir committed Apr 3, 2024
1 parent 48592e7 commit 58f3eb9
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 54 deletions.
111 changes: 57 additions & 54 deletions src/gentropy/susie_finemapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

import numpy as np
import pyspark.sql.functions as f
from pyspark.sql import DataFrame, Window
from pyspark.sql import DataFrame, Row, Window

from gentropy.common.session import Session
from gentropy.dataset.study_index import StudyIndex
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.summary_statistics import SummaryStatistics
from gentropy.datasource.gnomad.ld import GnomADLDMatrix
from gentropy.method.susie_inf import SUSIE_inf


Expand All @@ -26,26 +27,28 @@ class SusieFineMapperStep:
def susie_finemapper_chr_pos(
GWAS: SummaryStatistics,
session: Session,
chromosome: str,
position: int,
_studyId: str,
study_locus_row: Row,
study_index: StudyIndex,
window: int = 1_000_000,
L: int = 10,
) -> StudyLocus:
"""Susie fine-mapper function that uses Summary Statistics, chromosome and position as inputs.
Args:
GWAS (SummaryStatistics): GWAS summary statistics
session (Session): Spark session
chromosome (str): chromosome
position (int): position
_studyId (str): study ID
study_locus_row (Row): StudyLocus row
study_index (StudyIndex): StudyIndex object
window (int): window size for fine-mapping
L (int): number of causal variants
Returns:
StudyLocus: StudyLocus object with fine-mapped credible sets
"""
chromosome = study_locus_row.chromosome
position = study_locus_row.position
_studyId = study_locus_row.studyId

study_index_df = study_index._df
study_index_df = study_index_df.filter(study_index_df.studyId == _studyId)
_major_population = study_index_df.select(
Expand All @@ -71,59 +74,57 @@ def susie_finemapper_chr_pos(
& (f.col("position") <= position + (window / 2))
).withColumn("z", f.col("beta") / f.col("standardError"))

_z = np.array([row["z"] for row in GWAS_df.select("z").collect()])

# # Extract summary statistics
# _ss = (
# SummaryStatistics.get_locus_sumstats(session, window, _locus)
# .withColumn("z", f.col("beta") / f.col("standardError"))
# .withColumn("ref", f.split(f.col("variantId"), "_").getItem(2))
# .withColumn("alt", f.split(f.col("variantId"), "_").getItem(3))
# .select(
# "variantId",
# f.col("chromosome").alias("chr"),
# f.col("position").alias("pos"),
# "ref",
# "alt",
# "beta",
# "pValueMantissa",
# "pValueExponent",
# "effectAlleleFrequencyFromSource",
# "standardError",
# "z",
# )
# )

# Extract LD index
# _index = GnomADLDMatrix.get_locus_index(
# session, locus, window_size=window, major_population=_major_population
# )
# _join = (
# _ss.join(
# _index.alias("_index"),
# on=(
# (_ss["chr"] == _index["chromosome"])
# & (_ss["pos"] == _index["position"])
# & (_ss["ref"] == _index["referenceAllele"])
# & (_ss["alt"] == _index["alternateAllele"])
# ),
# )
# .drop("ref", "alt", "chr", "pos")
# .sort("idx")
# )

# Extracting z-scores and LD matrix, then running SuSiE-inf
# _ld = GnomADLDMatrix.get_locus_matrix(_join, gnomad_ancestry=_major_population)

_ld = 1
susie_output = SUSIE_inf.susie_inf(z=_z, LD=_ld, L=10)
ld_index = (
GnomADLDMatrix()
.get_locus_index(
study_locus_row=study_locus_row,
window_size=window,
major_population=_major_population,
)
.withColumn(
"variantId",
f.concat(
f.lit(chromosome),
f.lit("_"),
f.col("`locus.position`"),
f.lit("_"),
f.col("alleles").getItem(0),
f.lit("_"),
f.col("alleles").getItem(1),
).cast("string"),
)
)

gnomad_ld = GnomADLDMatrix.get_numpy_matrix(
ld_index, gnomad_ancestry=_major_population
)

GWAS_df = GWAS_df.toPandas()
ld_index = ld_index.toPandas()
ld_index = ld_index.reset_index()

# Filter out the variants that are not in the LD matrix, since they are not needed
df_columns = GWAS_df.columns
GWAS_df = GWAS_df.merge(ld_index, on="variantId", how="inner")
GWAS_df = GWAS_df[df_columns].reset_index()

merged_df = GWAS_df.merge(
ld_index, left_on="variantId", right_on="variantId", how="inner"
)
indices = merged_df["index_y"].values

ld_to_fm = gnomad_ld[indices][:, indices]
z_to_fm = GWAS_df["z"].values

susie_output = SUSIE_inf.susie_inf(z=z_to_fm, LD=ld_to_fm, L=L)
variant_index = session.spark.createDataFrame(GWAS_df)

return SusieFineMapperStep.susie_inf_to_studylocus(
susie_output=susie_output,
session=session,
_studyId=_studyId,
_region=_region,
variant_index=GWAS,
variant_index=variant_index,
)

@staticmethod
Expand Down Expand Up @@ -151,6 +152,7 @@ def susie_inf_to_studylocus(
variants = np.array(
[row["variantId"] for row in variant_index.select("variantId").collect()]
).reshape(-1, 1)

PIPs = susie_output["PIP"]
lbfs = susie_output["lbf_variable"]
mu = susie_output["mu"]
Expand Down Expand Up @@ -181,6 +183,7 @@ def susie_inf_to_studylocus(
win = Window.rowsBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)

cred_set = (
session.spark.createDataFrame(
cred_set.tolist(),
Expand Down
1 change: 1 addition & 0 deletions tests/gentropy/method/test_susie_inf.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def test_SUSIE_inf_convert_to_study_locus(
est_tausq=False,
)
gwas_df = sample_summary_statistics._df.limit(21)

L1 = SusieFineMapperStep.susie_inf_to_studylocus(
susie_output=susie_output,
session=session,
Expand Down

0 comments on commit 58f3eb9

Please sign in to comment.