
chore: reverting changes to ld.py
Daniel-Considine committed Mar 20, 2024
1 parent c168bab commit b23d5a9
Showing 1 changed file with 31 additions and 107 deletions.
src/gentropy/datasource/gnomad/ld.py (138 changed lines: 31 additions & 107 deletions)
@@ -7,12 +7,10 @@
 from typing import TYPE_CHECKING

 import hail as hl
-import numpy as np
 import pyspark.sql.functions as f
 from hail.linalg import BlockMatrix
 from pyspark.sql import Window

-from gentropy.common.session import Session
 from gentropy.common.spark_helpers import get_top_ranked_in_window, get_value_from_row
 from gentropy.common.utils import _liftover_loci, convert_gnomad_position_to_ensembl
 from gentropy.dataset.ld_index import LDIndex
@@ -220,9 +218,9 @@ def _resolve_variant_indices(
             DataFrame: Dataframe with variant IDs instead of `i` and `j` indices
         """
         ld_index_i = ld_index.selectExpr(
-            "idx as i", "variantId as variantId_i", "chromosome"
+            "idx as i", "variantId as variantIdI", "chromosome"
         )
-        ld_index_j = ld_index.selectExpr("idx as j", "variantId as variantId_j")
+        ld_index_j = ld_index.selectExpr("idx as j", "variantId as variantIdJ")
         return (
             ld_matrix.join(ld_index_i, on="i", how="inner")
             .join(ld_index_j, on="j", how="inner")
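
Note: the hunk above restores the camelCase aliases used throughout the file. As a minimal sketch of what `_resolve_variant_indices` does, with toy data (the `spark` session and the example values are assumptions, not part of the commit):

    # Toy inputs: an LD matrix keyed by integer positions, and the index
    # table mapping each position to a variant identifier.
    ld_matrix = spark.createDataFrame([(0, 1, 0.7)], ["i", "j", "r"])
    ld_index = spark.createDataFrame(
        [(0, "1_100_A_T", "1"), (1, "1_200_G_C", "1")],
        ["idx", "variantId", "chromosome"],
    )

    ld_index_i = ld_index.selectExpr("idx as i", "variantId as variantIdI", "chromosome")
    ld_index_j = ld_index.selectExpr("idx as j", "variantId as variantIdJ")

    # Two inner joins swap the positional indices for variant IDs.
    resolved = (
        ld_matrix.join(ld_index_i, on="i", how="inner")
        .join(ld_index_j, on="j", how="inner")
        .drop("i", "j")
    )
    # One row remains: variantIdI="1_100_A_T", variantIdJ="1_200_G_C", r=0.7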
@@ -240,35 +238,35 @@ def _transpose_ld_matrix(ld_matrix: DataFrame) -> DataFrame:
             DataFrame: Square LD matrix without diagonal duplicates

         Examples:
-            >>> df = spark.createDataFrame(
-            ...     [
-            ...         (1, 1, 1.0, "1", "AFR"),
-            ...         (1, 2, 0.5, "1", "AFR"),
-            ...         (2, 2, 1.0, "1", "AFR"),
-            ...     ],
-            ...     ["variantId_i", "variantId_j", "r", "chromosome", "population"],
-            ... )
-            >>> GnomADLDMatrix._transpose_ld_matrix(df).show()
-            +-----------+-----------+---+----------+----------+
-            |variantId_i|variantId_j|  r|chromosome|population|
-            +-----------+-----------+---+----------+----------+
-            |          1|          2|0.5|         1|       AFR|
-            |          1|          1|1.0|         1|       AFR|
-            |          2|          1|0.5|         1|       AFR|
-            |          2|          2|1.0|         1|       AFR|
-            +-----------+-----------+---+----------+----------+
-            <BLANKLINE>
+            >>> df = spark.createDataFrame(
+            ...     [
+            ...         (1, 1, 1.0, "1", "AFR"),
+            ...         (1, 2, 0.5, "1", "AFR"),
+            ...         (2, 2, 1.0, "1", "AFR"),
+            ...     ],
+            ...     ["variantIdI", "variantIdJ", "r", "chromosome", "population"],
+            ... )
+            >>> GnomADLDMatrix._transpose_ld_matrix(df).show()
+            +----------+----------+---+----------+----------+
+            |variantIdI|variantIdJ|  r|chromosome|population|
+            +----------+----------+---+----------+----------+
+            |         1|         2|0.5|         1|       AFR|
+            |         1|         1|1.0|         1|       AFR|
+            |         2|         1|0.5|         1|       AFR|
+            |         2|         2|1.0|         1|       AFR|
+            +----------+----------+---+----------+----------+
+            <BLANKLINE>
         """
         ld_matrix_transposed = ld_matrix.selectExpr(
-            "variantId_i as variantId_j",
-            "variantId_j as variantId_i",
+            "variantIdI as variantIdJ",
+            "variantIdJ as variantIdI",
             "r",
             "chromosome",
             "population",
         )
-        return ld_matrix.filter(
-            f.col("variantId_i") != f.col("variantId_j")
-        ).unionByName(ld_matrix_transposed)
+        return ld_matrix.filter(f.col("variantIdI") != f.col("variantIdJ")).unionByName(
+            ld_matrix_transposed
+        )

     def as_ld_index(
         self: GnomADLDMatrix,
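
Note: gnomAD LD matrices are stored as upper triangles, so each variant pair appears once and the diagonal holds the self-correlations. `_transpose_ld_matrix` mirrors the entries and unions the two halves; the filter on the diagonal is what keeps self-pairs from being counted twice. A sketch of the failure mode it avoids, reusing the doctest's toy data (the `spark` session is assumed):

    ld_matrix = spark.createDataFrame(
        [(1, 1, 1.0), (1, 2, 0.5), (2, 2, 1.0)],
        ["variantIdI", "variantIdJ", "r"],
    )
    ld_matrix_transposed = ld_matrix.selectExpr(
        "variantIdI as variantIdJ", "variantIdJ as variantIdI", "r"
    )

    # Union without filtering the diagonal first: every (v, v, 1.0)
    # self-pair shows up twice, once from each half.
    naive = ld_matrix.unionByName(ld_matrix_transposed)
    naive.groupBy("variantIdI", "variantIdJ").count().filter("count > 1").show()
    # Only the diagonal rows (1, 1) and (2, 2) come back duplicated.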
Expand Down Expand Up @@ -309,8 +307,8 @@ def as_ld_index(
GnomADLDMatrix._transpose_ld_matrix(
reduce(lambda df1, df2: df1.unionByName(df2), ld_indices_unaggregated)
)
.withColumnRenamed("variantIdI", "variantId")
.withColumnRenamed("variantIdJ", "tagVariantId")
.withColumnRenamed("variantId_i", "variantId")
.withColumnRenamed("variantId_j", "tagVariantId")
)
return LDIndex(
_df=self._aggregate_ld_index_across_populations(ld_index_unaggregated),
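
Note: `as_ld_index` stacks one half-matrix per gnomAD population before the rename restored above. A minimal sketch of that aggregation step (the `spark` session and the toy rows are assumptions):

    from functools import reduce

    cols = ["variantIdI", "variantIdJ", "r", "chromosome", "population"]
    ld_indices_unaggregated = [
        spark.createDataFrame([("v1", "v2", 0.5, "1", "afr")], cols),
        spark.createDataFrame([("v1", "v2", 0.6, "1", "nfe")], cols),
    ]

    # unionByName aligns columns by name rather than position, then the
    # matrix-style names are mapped onto the LDIndex schema.
    ld_index_unaggregated = (
        reduce(lambda df1, df2: df1.unionByName(df2), ld_indices_unaggregated)
        .withColumnRenamed("variantIdI", "variantId")
        .withColumnRenamed("variantIdJ", "tagVariantId")
    )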
@@ -347,6 +345,7 @@ def get_ld_variants(
                 & (f.col("position") <= end)
             )
             .select("chromosome", "position", "variantId", "idx")
+            .persist()
         )

         if ld_index_df.limit(1).count() == 0:
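
Note: the restored `.persist()` caches the filtered index because `ld_index_df` is consumed by at least two actions: the emptiness check just below and the later matrix lookup. Without it, Spark would re-run the filter for each action. A minimal illustration of the pattern (the `spark` session is assumed):

    df = spark.range(100).persist()  # lazy: nothing is cached yet

    if df.limit(1).count() == 0:  # first action materializes the cache
        print("empty input")

    total = df.count()  # second action reuses the cached rows
    df.unpersist()      # release the cache when done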
@@ -396,20 +395,20 @@ def _extract_square_matrix(
             .join(
                 ld_index_df.select(
                     f.col("idx").alias("idx_i"),
-                    f.col("variantId").alias("variantId_i"),
+                    f.col("variantId").alias("variantIdI"),
                 ),
                 on="idx_i",
                 how="inner",
             )
             .join(
                 ld_index_df.select(
                     f.col("idx").alias("idx_j"),
-                    f.col("variantId").alias("variantId_j"),
+                    f.col("variantId").alias("variantIdJ"),
                 ),
                 on="idx_j",
                 how="inner",
             )
-            .select("variantId_i", "variantId_j", "r")
+            .select("variantIdI", "variantIdJ", "r")
         )

     def get_ld_matrix_slice(
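
Note: `_extract_square_matrix` returns the slice in long format, one `(variantIdI, variantIdJ, r)` row per stored pair. Where a dense layout is wanted downstream, one option is a pivot; a sketch under assumed toy data (not part of this commit):

    import pyspark.sql.functions as f

    # A long-format slice like the one _extract_square_matrix returns.
    square_long = spark.createDataFrame(
        [("v1", "v1", 1.0), ("v1", "v2", 0.5), ("v2", "v2", 1.0)],
        ["variantIdI", "variantIdJ", "r"],
    )

    # One row per variantIdI and one column per variantIdJ; pairs absent
    # from the triangular slice surface as nulls.
    wide = (
        square_long.groupBy("variantIdI")
        .pivot("variantIdJ")
        .agg(f.first("r"))
        .orderBy("variantIdI")
    )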
Expand Down Expand Up @@ -449,78 +448,3 @@ def get_ld_matrix_slice(
.alias("r"),
)
)

@staticmethod
def get_locus_index(
session: Session,
study_locus_row: DataFrame,
window_size: int,
ld_index_path: str,
major_population: str = "nfe",
) -> DataFrame:
"""Extract hail matrix index from StudyLocus rows.
Args:
session (Session): Spark session
study_locus_row (DataFrame): Study-locus row
window_size (int): Window size to extract from gnomad matrix
ld_index_path (str): Optional path to the LD index parquet
major_population (str): Major population to extract from gnomad matrix, default is "nfe"
Returns:
DataFrame: Returns the index of the gnomad matrix for the locus
"""
_df = (
study_locus_row.withColumn("start", f.col("position") - (window_size / 2))
.withColumn("end", f.col("position") + (window_size / 2))
.alias("_df")
)

_matrix_index = session.spark.read.parquet(
ld_index_path.format(POP=major_population)
)

_index_joined = (
_df.alias("df")
.join(
_matrix_index.alias("matrix_index"),
(f.col("df.chromosome") == f.col("matrix_index.chromosome"))
& (f.col("df.start") <= f.col("matrix_index.position"))
& (f.col("df.end") >= f.col("matrix_index.position")),
)
.select(
"matrix_index.chromosome",
"matrix_index.position",
"referenceAllele",
"alternateAllele",
"idx",
)
.sort("idx")
)

return _index_joined

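Note: `get_locus_index`, deleted above, sliced the Hail matrix index with an interval join over `position ± window_size / 2`. Before the revert it could be called roughly like this (the `session` object, the path template, and the input row are all hypothetical; `ld_index_path` expects a `{POP}` placeholder):

    locus_row = session.spark.createDataFrame(
        [("1", 55_000_000)], ["chromosome", "position"]
    )

    locus_index = GnomADLDMatrix.get_locus_index(
        session=session,  # a gentropy Session instance
        study_locus_row=locus_row,
        window_size=1_000_000,  # 1 Mb window centred on the locus
        ld_index_path="gs://some-bucket/ld_index_{POP}.parquet",  # hypothetical
        major_population="nfe",
    )
    # Result: chromosome, position, referenceAllele, alternateAllele, idx,
    # sorted by idx and ready for get_locus_matrix below.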
-    @staticmethod
-    def get_locus_matrix(
-        locus_index: DataFrame,
-        gnomad_ancestry: str,
-    ) -> np.ndarray:
-        """Extract the LD block matrix for a locus.
-
-        Args:
-            locus_index (DataFrame): hail matrix variant index table
-            gnomad_ancestry (str): GnomAD major ancestry label, e.g. `nfe`
-
-        Returns:
-            np.ndarray: LD block matrix for the locus
-        """
-        idx = [row["idx"] for row in locus_index.select("idx").collect()]
-
-        half_matrix = (
-            BlockMatrix.read(
-                GnomADLDMatrix.ld_matrix_template.format(POP=gnomad_ancestry)
-            )
-            .filter(idx, idx)
-            .to_numpy()
-        )
-
-        return (half_matrix + half_matrix.T) - np.diag(np.diag(half_matrix))

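Note: gnomAD block matrices store only the upper triangle, so the last line of `get_locus_matrix` rebuilds the symmetric matrix: adding the transpose doubles the diagonal, and subtracting `np.diag(np.diag(...))` puts it back. A self-contained numpy check of that identity:

    import numpy as np

    # Toy upper-triangular LD block, as BlockMatrix.to_numpy() would return it.
    half_matrix = np.array([
        [1.0, 0.5, 0.2],
        [0.0, 1.0, 0.3],
        [0.0, 0.0, 1.0],
    ])

    full = (half_matrix + half_matrix.T) - np.diag(np.diag(half_matrix))
    # full == [[1.0, 0.5, 0.2],
    #          [0.5, 1.0, 0.3],
    #          [0.2, 0.3, 1.0]]
    assert (full == full.T).all()
    assert (np.diag(full) == 1.0).all()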