Skip to content

Commit

Permalink
feat: LD index and block matrix extraction for a studyLocus (#463)
Browse files Browse the repository at this point in the history
* test: adding test for pairwiseLD

* feat: adding ld matrix extraction

* chore: merge from dev

* feat: index and block matrix extraction for studyLocus

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* chore: updating some test files to gentropy

* chore: updating tests

* chore: updating pairwise_ld_schema for tests

* chore: updating pairwise_ld tests

* chore: fix ld_pairwise tests

* chore: fix pairwise_ld tests

* chore: fix tests

* chore: fix tests

* chore: fixing typing for tests

* chore: fixing tests

* chore: fixing ld tests

* Update src/gentropy/dataset/study_index.py

Co-authored-by: Daniel Suveges <[email protected]>

* feat: moving functions to their appropriate locations and improving logic

* fix: optimise conversion of BM to NumPy

* feat: updating get_locus_index to allow for just chromosome and position inputs

* fix: suggested changes

* Update study_index.py

* fix: changes to datasource/gnomad/ld.py

* feat: updated method for ld_index extraction

* fix: sorting idx in hail

---------

Co-authored-by: Daniel Suveges <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Apr 3, 2024
1 parent d76ebbe commit 56067e7
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 32 deletions.
23 changes: 23 additions & 0 deletions src/gentropy/assets/schemas/pairwise_ld.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"fields": [
{
"metadata": {},
"name": "variantIdI",
"nullable": false,
"type": "string"
},
{
"metadata": {},
"name": "variantIdJ",
"nullable": false,
"type": "string"
},
{
"metadata": {},
"name": "r",
"nullable": false,
"type": "double"
}
],
"type": "struct"
}
104 changes: 104 additions & 0 deletions src/gentropy/dataset/pairwise_ld.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Pairwise LD dataset."""
from __future__ import annotations

from dataclasses import dataclass, field
from math import sqrt
from typing import TYPE_CHECKING

import numpy as np
from pyspark.sql import functions as f
from pyspark.sql import types as t

from gentropy.common.schemas import parse_spark_schema
from gentropy.dataset.dataset import Dataset

if TYPE_CHECKING:
from pyspark.sql.types import StructType


@dataclass
class PairwiseLD(Dataset):
    """Pairwise variant correlation dataset.

    Each row holds the correlation `r` between two variants (`variantIdI` and
    `variantIdJ`). Upon creation the table is checked to contain a square number of
    rows, so that the `r` values can be reshaped into a square matrix.
    """

    # (rows, columns) of the square matrix represented by the table; set in __post_init__.
    dimension: tuple[int, int] = field(init=False)

    def __post_init__(self: PairwiseLD) -> None:
        """Validate the dataset upon creation and record the matrix dimension.

        Besides the schema, a pairwise LD table is expected to have a square number of rows.

        Raises:
            AssertionError: If the number of rows is not a perfect square.
        """
        # NOTE(review): this overrides any `__post_init__` defined on `Dataset` without
        # calling it - confirm that schema validation still happens as intended.
        row_count = self.df.count()
        side = int(sqrt(row_count))

        # Raised explicitly (rather than with `assert`) so that the validation is not
        # stripped when Python runs with -O; the exception type is kept for callers.
        if side * side != row_count:
            raise AssertionError(
                f"The number of rows in a pairwise LD table has to be square. Got: {row_count}"
            )

        self.dimension = (side, side)

    @classmethod
    def get_schema(cls: type[PairwiseLD]) -> StructType:
        """Provide the schema for the PairwiseLD dataset.

        Returns:
            StructType: The schema of the PairwiseLD dataset.
        """
        return parse_spark_schema("pairwise_ld.json")

    def overlap_with_locus(self: PairwiseLD, locus_variants: list[str]) -> PairwiseLD:
        """Subset pairwise LD table with locus.

        Only rows where both `variantIdI` and `variantIdJ` are listed in
        `locus_variants` are kept.

        Args:
            locus_variants (list[str]): List of variants found in the locus.

        Returns:
            PairwiseLD: New dataset restricted to variant pairs found in the locus.
        """
        both_in_locus = f.col("variantIdI").isin(locus_variants) & f.col(
            "variantIdJ"
        ).isin(locus_variants)
        return PairwiseLD(
            _df=self.df.filter(both_in_locus),
            _schema=PairwiseLD.get_schema(),
        )

    def r_to_numpy_matrix(self) -> np.ndarray:
        """Convert pairwise LD to a numpy square matrix.

        Rows and columns are ordered by the position parsed from the second
        `_`-separated token of the variant identifier; the full identifier is used
        as a tie-break so that variants sharing a position have a stable order that
        matches `get_variant_list`.

        Returns:
            np.ndarray: 2D square matrix with r values.
        """
        ordered_r = (
            self.df.select(
                "variantIdI",
                "variantIdJ",
                f.split("variantIdI", "_")[1].cast(t.IntegerType()).alias("position_i"),
                f.split("variantIdJ", "_")[1].cast(t.IntegerType()).alias("position_j"),
                "r",
            )
            .orderBy(
                f.col("position_i").asc(),
                f.col("variantIdI").asc(),
                f.col("position_j").asc(),
                f.col("variantIdJ").asc(),
            )
            .select("r")
            .collect()
        )
        return np.array(ordered_r).reshape(self.dimension)

    def get_variant_list(self) -> list[str]:
        """Return a list of unique variants from the dataset.

        Returns:
            list[str]: list of variant identifiers sorted by position.
        """
        variants = (
            self.df.select(
                f.col("variantIdI").alias("variantId"),
                f.split(f.col("variantIdI"), "_")[1]
                .cast(t.IntegerType())
                .alias("position"),
            )
            # Every variant appears once per pair it takes part in; keep each only once.
            .distinct()
            # The identifier tie-break keeps the order in line with `r_to_numpy_matrix`.
            .orderBy(f.col("position").asc(), f.col("variantId").asc())
            .collect()
        )
        return [row["variantId"] for row in variants]
135 changes: 103 additions & 32 deletions src/gentropy/datasource/gnomad/ld.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING

import hail as hl
import numpy as np
import pyspark.sql.functions as f
from hail.linalg import BlockMatrix
from pyspark.sql import Window
Expand All @@ -16,7 +17,7 @@
from gentropy.dataset.ld_index import LDIndex

if TYPE_CHECKING:
from pyspark.sql import DataFrame
from pyspark.sql import DataFrame, Row


@dataclass
Expand All @@ -34,6 +35,7 @@ class GnomADLDMatrix:

ld_matrix_template: str = "gs://gcp-public-data--gnomad/release/2.1.1/ld/gnomad.genomes.r2.1.1.{POP}.common.adj.ld.bm"
ld_index_raw_template: str = "gs://gcp-public-data--gnomad/release/2.1.1/ld/gnomad.genomes.r2.1.1.{POP}.common.ld.variant_indices.ht"
liftover_ht_path: str = "gs://gcp-public-data--gnomad/release/2.1.1/liftover_grch38/ht/genomes/gnomad.genomes.r2.1.1.sites.liftover_grch38.ht"
grch37_to_grch38_chain_path: str = (
"gs://hail-common/references/grch37_to_grch38.over.chain.gz"
)
Expand Down Expand Up @@ -218,9 +220,9 @@ def _resolve_variant_indices(
DataFrame: Dataframe with variant IDs instead of `i` and `j` indices
"""
ld_index_i = ld_index.selectExpr(
"idx as i", "variantId as variantId_i", "chromosome"
"idx as i", "variantId as variantIdI", "chromosome"
)
ld_index_j = ld_index.selectExpr("idx as j", "variantId as variantId_j")
ld_index_j = ld_index.selectExpr("idx as j", "variantId as variantIdJ")
return (
ld_matrix.join(ld_index_i, on="i", how="inner")
.join(ld_index_j, on="j", how="inner")
Expand All @@ -238,35 +240,35 @@ def _transpose_ld_matrix(ld_matrix: DataFrame) -> DataFrame:
DataFrame: Square LD matrix without diagonal duplicates
Examples:
>>> df = spark.createDataFrame(
... [
... (1, 1, 1.0, "1", "AFR"),
... (1, 2, 0.5, "1", "AFR"),
... (2, 2, 1.0, "1", "AFR"),
... ],
... ["variantId_i", "variantId_j", "r", "chromosome", "population"],
... )
>>> GnomADLDMatrix._transpose_ld_matrix(df).show()
+-----------+-----------+---+----------+----------+
|variantId_i|variantId_j| r|chromosome|population|
+-----------+-----------+---+----------+----------+
| 1| 2|0.5| 1| AFR|
| 1| 1|1.0| 1| AFR|
| 2| 1|0.5| 1| AFR|
| 2| 2|1.0| 1| AFR|
+-----------+-----------+---+----------+----------+
<BLANKLINE>
>>> df = spark.createDataFrame(
... [
... (1, 1, 1.0, "1", "AFR"),
... (1, 2, 0.5, "1", "AFR"),
... (2, 2, 1.0, "1", "AFR"),
... ],
... ["variantIdI", "variantIdJ", "r", "chromosome", "population"],
... )
>>> GnomADLDMatrix._transpose_ld_matrix(df).show()
+----------+----------+---+----------+----------+
|variantIdI|variantIdJ| r|chromosome|population|
+----------+----------+---+----------+----------+
| 1| 2|0.5| 1| AFR|
| 1| 1|1.0| 1| AFR|
| 2| 1|0.5| 1| AFR|
| 2| 2|1.0| 1| AFR|
+----------+----------+---+----------+----------+
<BLANKLINE>
"""
ld_matrix_transposed = ld_matrix.selectExpr(
"variantId_i as variantId_j",
"variantId_j as variantId_i",
"variantIdI as variantIdJ",
"variantIdJ as variantIdI",
"r",
"chromosome",
"population",
)
return ld_matrix.filter(
f.col("variantId_i") != f.col("variantId_j")
).unionByName(ld_matrix_transposed)
return ld_matrix.filter(f.col("variantIdI") != f.col("variantIdJ")).unionByName(
ld_matrix_transposed
)

def as_ld_index(
self: GnomADLDMatrix,
Expand Down Expand Up @@ -307,8 +309,8 @@ def as_ld_index(
GnomADLDMatrix._transpose_ld_matrix(
reduce(lambda df1, df2: df1.unionByName(df2), ld_indices_unaggregated)
)
.withColumnRenamed("variantId_i", "variantId")
.withColumnRenamed("variantId_j", "tagVariantId")
.withColumnRenamed("variantIdI", "variantId")
.withColumnRenamed("variantIdJ", "tagVariantId")
)
return LDIndex(
_df=self._aggregate_ld_index_across_populations(ld_index_unaggregated),
Expand Down Expand Up @@ -345,7 +347,6 @@ def get_ld_variants(
& (f.col("position") <= end)
)
.select("chromosome", "position", "variantId", "idx")
.persist()
)

if ld_index_df.limit(1).count() == 0:
Expand Down Expand Up @@ -395,20 +396,20 @@ def _extract_square_matrix(
.join(
ld_index_df.select(
f.col("idx").alias("idx_i"),
f.col("variantId").alias("variantId_i"),
f.col("variantId").alias("variantIdI"),
),
on="idx_i",
how="inner",
)
.join(
ld_index_df.select(
f.col("idx").alias("idx_j"),
f.col("variantId").alias("variantId_j"),
f.col("variantId").alias("variantIdJ"),
),
on="idx_j",
how="inner",
)
.select("variantId_i", "variantId_j", "r")
.select("variantIdI", "variantIdJ", "r")
)

def get_ld_matrix_slice(
Expand Down Expand Up @@ -448,3 +449,73 @@ def get_ld_matrix_slice(
.alias("r"),
)
)

def get_locus_index(
    self: GnomADLDMatrix,
    study_locus_row: Row,
    window_size: int = 1_000_000,
    major_population: str = "nfe",
) -> DataFrame:
    """Extract hail matrix index from StudyLocus rows.

    The liftover table is restricted to a window centred on the study-locus
    position, re-keyed by `original_locus` and `alleles`, and inner-joined with the
    LD variant index of the requested population.

    Args:
        study_locus_row (Row): Study-locus row; its `chromosome` and `position` fields are read
        window_size (int): Window size to extract from gnomad matrix
        major_population (str): Major population to extract from gnomad matrix, default is "nfe"

    Returns:
        DataFrame: Returns the index of the gnomad matrix for the locus
    """
    # The liftover table is filtered on "chr"-prefixed contig names.
    contig = "chr" + study_locus_row["chromosome"]
    centre = study_locus_row["position"]
    half_window = window_size // 2
    window_start = centre - half_window
    window_end = centre + half_window

    all_liftover = hl.read_table(self.liftover_ht_path)
    lifted_locus = all_liftover.locus
    within_window = (
        (lifted_locus.contig == contig)
        & (lifted_locus.position >= window_start)
        & (lifted_locus.position <= window_end)
    )

    windowed = all_liftover.filter(within_window)
    # Drop the existing key so the table can be re-keyed on the original coordinates.
    unkeyed = windowed.key_by().select("locus", "alleles", "original_locus")
    locus_liftover = unkeyed.key_by("original_locus", "alleles").naive_coalesce(20)

    population_index_path = self.ld_index_raw_template.format(POP=major_population)
    population_index = hl.read_table(population_index_path)

    # Ordering by `idx` happens in hail, before the conversion to a Spark DataFrame.
    matched = locus_liftover.join(population_index, how="inner")
    return matched.order_by("idx").to_spark()

@staticmethod
def get_numpy_matrix(
    locus_index: DataFrame,
    gnomad_ancestry: str = "nfe",
) -> np.ndarray:
    """Extract the LD block matrix for a locus.

    The `idx` values of the locus index select both the rows and the columns of
    the ancestry-specific block matrix; the selected block is then mirrored across
    its diagonal.

    Args:
        locus_index (DataFrame): hail matrix variant index table
        gnomad_ancestry (str): GnomAD major ancestry label eg. `nfe`

    Returns:
        np.ndarray: LD block matrix for the locus
    """
    index_rows = locus_index.select("idx").collect()
    matrix_indices = [index_row["idx"] for index_row in index_rows]

    matrix_path = GnomADLDMatrix.ld_matrix_template.format(POP=gnomad_ancestry)
    block_matrix = BlockMatrix.read(matrix_path)
    # NOTE(review): presumably only one triangle of the matrix is populated - confirm.
    triangle = block_matrix.filter(matrix_indices, matrix_indices).to_numpy()

    # Adding the transpose counts the diagonal twice, so it is subtracted once.
    mirrored = triangle + triangle.T
    diagonal = np.diag(np.diag(triangle))
    return mirrored - diagonal
Loading

0 comments on commit 56067e7

Please sign in to comment.