diff --git a/src/gentropy/assets/schemas/pairwise_ld.json b/src/gentropy/assets/schemas/pairwise_ld.json
new file mode 100644
index 000000000..bac781ac3
--- /dev/null
+++ b/src/gentropy/assets/schemas/pairwise_ld.json
@@ -0,0 +1,23 @@
+{
+  "fields": [
+    {
+      "metadata": {},
+      "name": "variantIdI",
+      "nullable": false,
+      "type": "string"
+    },
+    {
+      "metadata": {},
+      "name": "variantIdJ",
+      "nullable": false,
+      "type": "string"
+    },
+    {
+      "metadata": {},
+      "name": "r",
+      "nullable": false,
+      "type": "double"
+    }
+  ],
+  "type": "struct"
+}
diff --git a/src/gentropy/dataset/pairwise_ld.py b/src/gentropy/dataset/pairwise_ld.py
new file mode 100644
index 000000000..9650efa32
--- /dev/null
+++ b/src/gentropy/dataset/pairwise_ld.py
@@ -0,0 +1,116 @@
+"""Pairwise LD dataset."""
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from math import isqrt
+from typing import TYPE_CHECKING
+
+import numpy as np
+from pyspark.sql import functions as f
+from pyspark.sql import types as t
+
+from gentropy.common.schemas import parse_spark_schema
+from gentropy.dataset.dataset import Dataset
+
+if TYPE_CHECKING:
+    from pyspark.sql.types import StructType
+
+
+@dataclass
+class PairwiseLD(Dataset):
+    """Pairwise variant correlation dataset.
+
+    This class captures logic applied on pairwise linkage data
+    by validation ensuring data quality.
+    """
+
+    dimension: tuple[int, int] = field(init=False)
+
+    def __post_init__(self: PairwiseLD) -> None:
+        """Validate the dataset upon creation.
+
+        Besides the schema, a pairwise LD table is expected to have a square number of rows.
+
+        Raises:
+            AssertionError: If the number of rows is not a square number.
+        """
+        # NOTE(review): this override does not call `super().__post_init__()` -
+        # confirm the `Dataset` base class has no validation of its own to run.
+        row_count = self.df.count()
+        # Integer square root is exact, unlike comparing floats from `sqrt`.
+        side = isqrt(row_count)
+
+        # Raised explicitly: a bare `assert` is stripped when running with `-O`.
+        if side * side != row_count:
+            raise AssertionError(
+                f"The number of rows in a pairwise LD table has to be square. Got: {row_count}"
+            )
+
+        self.dimension = (side, side)
+
+    @classmethod
+    def get_schema(cls: type[PairwiseLD]) -> StructType:
+        """Provide the schema for the PairwiseLD dataset.
+
+        Returns:
+            StructType: The schema of the PairwiseLD dataset.
+        """
+        return parse_spark_schema("pairwise_ld.json")
+
+    def overlap_with_locus(self: PairwiseLD, locus_variants: list[str]) -> PairwiseLD:
+        """Subset pairwise LD table with locus.
+
+        Args:
+            locus_variants (list[str]): List of variants found in the locus.
+
+        Returns:
+            PairwiseLD: Pairs where both variants are found in the locus.
+        """
+        return PairwiseLD(
+            _df=(
+                self.df.filter(
+                    f.col("variantIdI").isin(locus_variants)
+                    & f.col("variantIdJ").isin(locus_variants)
+                )
+            ),
+            _schema=PairwiseLD.get_schema(),
+        )
+
+    def r_to_numpy_matrix(self) -> np.ndarray:
+        """Convert pairwise LD to a numpy square matrix.
+
+        Returns:
+            np.ndarray: 2D square matrix with r values.
+        """
+        return np.array(
+            self.df.select(
+                f.split("variantIdI", "_")[1].cast(t.IntegerType()).alias("position_i"),
+                f.split("variantIdJ", "_")[1].cast(t.IntegerType()).alias("position_j"),
+                "r",
+            )
+            .orderBy(f.col("position_i").asc(), f.col("position_j").asc())
+            .select("r")
+            .collect()
+        ).reshape(self.dimension)
+
+    def get_variant_list(self) -> list[str]:
+        """Return a list of unique variants from the dataset.
+
+        Returns:
+            list[str]: list of variant identifiers sorted by position.
+        """
+        return [
+            row["variantId"]
+            for row in (
+                self.df.select(
+                    f.col("variantIdI").alias("variantId"),
+                    f.split(f.col("variantIdI"), "_")[1]
+                    .cast(t.IntegerType())
+                    .alias("position"),
+                )
+                # Every variant is repeated once per pair; keep a single copy.
+                .distinct()
+                .orderBy(f.col("position").asc())
+                .collect()
+            )
+        ]
diff --git a/src/gentropy/datasource/gnomad/ld.py b/src/gentropy/datasource/gnomad/ld.py
index 1d2d6c18f..a19cbb06b 100644
--- a/src/gentropy/datasource/gnomad/ld.py
+++ b/src/gentropy/datasource/gnomad/ld.py
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING
 
 import hail as hl
+import numpy as np
 import pyspark.sql.functions as f
 from hail.linalg import BlockMatrix
 from pyspark.sql import Window
@@ -16,7 +17,7 @@
 from gentropy.dataset.ld_index import LDIndex
 
 if TYPE_CHECKING:
-    from pyspark.sql import DataFrame
+    from pyspark.sql import DataFrame, Row
 
 
 @dataclass
@@ -34,6 +35,7 @@ class GnomADLDMatrix:
 
     ld_matrix_template: str = "gs://gcp-public-data--gnomad/release/2.1.1/ld/gnomad.genomes.r2.1.1.{POP}.common.adj.ld.bm"
     ld_index_raw_template: str = "gs://gcp-public-data--gnomad/release/2.1.1/ld/gnomad.genomes.r2.1.1.{POP}.common.ld.variant_indices.ht"
+    liftover_ht_path: str = "gs://gcp-public-data--gnomad/release/2.1.1/liftover_grch38/ht/genomes/gnomad.genomes.r2.1.1.sites.liftover_grch38.ht"
     grch37_to_grch38_chain_path: str = (
         "gs://hail-common/references/grch37_to_grch38.over.chain.gz"
     )
@@ -218,9 +220,9 @@ def _resolve_variant_indices(
             DataFrame: Dataframe with variant IDs instead of `i` and `j` indices
         """
         ld_index_i = ld_index.selectExpr(
-            "idx as i", "variantId as variantId_i", "chromosome"
+            "idx as i", "variantId as variantIdI", "chromosome"
         )
-        ld_index_j = ld_index.selectExpr("idx as j", "variantId as variantId_j")
+        ld_index_j = ld_index.selectExpr("idx as j", "variantId as variantIdJ")
         return (
             ld_matrix.join(ld_index_i, on="i", how="inner")
             .join(ld_index_j, on="j", how="inner")
@@ -238,35 +240,35 @@ def _transpose_ld_matrix(ld_matrix: DataFrame) -> DataFrame:
             DataFrame: Square LD matrix without diagonal duplicates
 
         Examples:
-        >>> df = spark.createDataFrame(
-        ...     [
-        ...         (1, 1, 1.0, "1", "AFR"),
-        ...         (1, 2, 0.5, "1", "AFR"),
-        ...         (2, 2, 1.0, "1", "AFR"),
-        ...     ],
-        ...     ["variantId_i", "variantId_j", "r", "chromosome", "population"],
-        ... )
-        >>> GnomADLDMatrix._transpose_ld_matrix(df).show()
-        +-----------+-----------+---+----------+----------+
-        |variantId_i|variantId_j|  r|chromosome|population|
-        +-----------+-----------+---+----------+----------+
-        |          1|          2|0.5|         1|       AFR|
-        |          1|          1|1.0|         1|       AFR|
-        |          2|          1|0.5|         1|       AFR|
-        |          2|          2|1.0|         1|       AFR|
-        +-----------+-----------+---+----------+----------+
-        
+            >>> df = spark.createDataFrame(
+            ...     [
+            ...         (1, 1, 1.0, "1", "AFR"),
+            ...         (1, 2, 0.5, "1", "AFR"),
+            ...         (2, 2, 1.0, "1", "AFR"),
+            ...     ],
+            ...     ["variantIdI", "variantIdJ", "r", "chromosome", "population"],
+            ... )
+            >>> GnomADLDMatrix._transpose_ld_matrix(df).show()
+            +----------+----------+---+----------+----------+
+            |variantIdI|variantIdJ|  r|chromosome|population|
+            +----------+----------+---+----------+----------+
+            |         1|         2|0.5|         1|       AFR|
+            |         1|         1|1.0|         1|       AFR|
+            |         2|         1|0.5|         1|       AFR|
+            |         2|         2|1.0|         1|       AFR|
+            +----------+----------+---+----------+----------+
+            
         """
         ld_matrix_transposed = ld_matrix.selectExpr(
-            "variantId_i as variantId_j",
-            "variantId_j as variantId_i",
+            "variantIdI as variantIdJ",
+            "variantIdJ as variantIdI",
             "r",
             "chromosome",
             "population",
         )
-        return ld_matrix.filter(
-            f.col("variantId_i") != f.col("variantId_j")
-        ).unionByName(ld_matrix_transposed)
+        return ld_matrix.filter(f.col("variantIdI") != f.col("variantIdJ")).unionByName(
+            ld_matrix_transposed
+        )
 
     def as_ld_index(
         self: GnomADLDMatrix,
@@ -307,8 +309,8 @@ def as_ld_index(
             GnomADLDMatrix._transpose_ld_matrix(
                 reduce(lambda df1, df2: df1.unionByName(df2), ld_indices_unaggregated)
             )
-            .withColumnRenamed("variantId_i", "variantId")
-            .withColumnRenamed("variantId_j", "tagVariantId")
+            .withColumnRenamed("variantIdI", "variantId")
+            .withColumnRenamed("variantIdJ", "tagVariantId")
         )
         return LDIndex(
             _df=self._aggregate_ld_index_across_populations(ld_index_unaggregated),
@@ -345,7 +347,6 @@ def get_ld_variants(
                 & (f.col("position") <= end)
             )
             .select("chromosome", "position", "variantId", "idx")
-            .persist()
         )
 
         if ld_index_df.limit(1).count() == 0:
@@ -395,7 +396,7 @@ def _extract_square_matrix(
             .join(
                 ld_index_df.select(
                     f.col("idx").alias("idx_i"),
-                    f.col("variantId").alias("variantId_i"),
+                    f.col("variantId").alias("variantIdI"),
                 ),
                 on="idx_i",
                 how="inner",
@@ -403,12 +404,12 @@ def _extract_square_matrix(
             .join(
                 ld_index_df.select(
                     f.col("idx").alias("idx_j"),
-                    f.col("variantId").alias("variantId_j"),
+                    f.col("variantId").alias("variantIdJ"),
                 ),
                 on="idx_j",
                 how="inner",
             )
-            .select("variantId_i", "variantId_j", "r")
+            .select("variantIdI", "variantIdJ", "r")
         )
 
     def get_ld_matrix_slice(
@@ -448,3 +449,73 @@ def get_ld_matrix_slice(
                 .alias("r"),
             )
         )
+
+    def get_locus_index(
+        self: GnomADLDMatrix,
+        study_locus_row: Row,
+        window_size: int = 1_000_000,
+        major_population: str = "nfe",
+    ) -> DataFrame:
+        """Extract hail matrix index from StudyLocus rows.
+
+        Args:
+            study_locus_row (Row): Study-locus row
+            window_size (int): Window size to extract from gnomad matrix
+            major_population (str): Major population to extract from gnomad matrix, default is "nfe"
+
+        Returns:
+            DataFrame: Returns the index of the gnomad matrix for the locus
+
+        """
+        chromosome = str("chr" + study_locus_row["chromosome"])
+        start = study_locus_row["position"] - window_size // 2
+        end = study_locus_row["position"] + window_size // 2
+
+        liftover_ht = hl.read_table(self.liftover_ht_path)
+        liftover_ht = (
+            liftover_ht.filter(
+                (liftover_ht.locus.contig == chromosome)
+                & (liftover_ht.locus.position >= start)
+                & (liftover_ht.locus.position <= end)
+            )
+            .key_by()
+            .select("locus", "alleles", "original_locus")
+            .key_by("original_locus", "alleles")
+            .naive_coalesce(20)
+        )
+
+        hail_index = hl.read_table(
+            self.ld_index_raw_template.format(POP=major_population)
+        )
+
+        joined_index = (
+            liftover_ht.join(hail_index, how="inner").order_by("idx").to_spark()
+        )
+
+        return joined_index
+
+    @staticmethod
+    def get_numpy_matrix(
+        locus_index: DataFrame,
+        gnomad_ancestry: str = "nfe",
+    ) -> np.ndarray:
+        """Extract the LD block matrix for a locus.
+
+        Args:
+            locus_index (DataFrame): hail matrix variant index table
+            gnomad_ancestry (str): GnomAD major ancestry label eg. `nfe`
+
+        Returns:
+            np.ndarray: LD block matrix for the locus
+        """
+        idx = [row["idx"] for row in locus_index.select("idx").collect()]
+
+        half_matrix = (
+            BlockMatrix.read(
+                GnomADLDMatrix.ld_matrix_template.format(POP=gnomad_ancestry)
+            )
+            .filter(idx, idx)
+            .to_numpy()
+        )
+
+        return (half_matrix + half_matrix.T) - np.diag(np.diag(half_matrix))
diff --git a/tests/gentropy/dataset/test_pairwise_ld.py b/tests/gentropy/dataset/test_pairwise_ld.py
new file mode 100644
index 000000000..11ebf75ca
--- /dev/null
+++ b/tests/gentropy/dataset/test_pairwise_ld.py
@@ -0,0 +1,100 @@
+"""Testing pairwise LD dataset."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pytest
+from gentropy.dataset.pairwise_ld import PairwiseLD
+from pyspark.sql import functions as f
+from pyspark.sql.window import Window
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
+
+class TestPairwiseLD:
+    """Test suite for pairwise LD dataset and associated methods."""
+
+    variants = [
+        "1_8_A_C",
+        "1_9_A_C",
+        "1_10_A_C",
+        "1_99_A_C",
+    ]
+
+    @pytest.fixture(scope="class")
+    def mock_pairwise_ld(self: TestPairwiseLD, spark: SparkSession) -> PairwiseLD:
+        """Generate a mock pairwise LD dataset.
+
+        Args:
+            spark (SparkSession): Spark session.
+
+        Returns:
+            PairwiseLD: Mock dataset with one row for every ordered pair of `variants`.
+        """
+        spark = spark.builder.getOrCreate()
+
+        data = [(v1, v2) for v1 in self.variants for v2 in self.variants]
+        return PairwiseLD(
+            _df=(
+                spark.createDataFrame(data, ["variantIdI", "variantIdJ"])
+                .withColumn(
+                    "r",
+                    f.row_number()
+                    .over(Window.partitionBy(f.lit("x")).orderBy("variantIdI"))
+                    .cast("double"),
+                )
+                .withColumn(
+                    "r",
+                    f.when(f.col("variantIdI") == f.col("variantIdJ"), 1.0).otherwise(
+                        f.col("r")
+                    ),
+                )
+                .persist()
+            ),
+            _schema=PairwiseLD.get_schema(),
+        )
+
+    @staticmethod
+    def test_pairwise_ld__type(mock_pairwise_ld: PairwiseLD) -> None:
+        """Testing type."""
+        assert isinstance(mock_pairwise_ld, PairwiseLD)
+
+    def test_pairwise_ld__get_variants(
+        self: TestPairwiseLD, mock_pairwise_ld: PairwiseLD
+    ) -> None:
+        """Testing function that returns list of variants from the LD table.
+
+        Args:
+            mock_pairwise_ld (PairwiseLD): Mock pairwise LD dataset.
+        """
+        # `variants` is listed in ascending position order, each variant once.
+        assert mock_pairwise_ld.get_variant_list() == self.variants
+
+    def test_pairwise_ld__r_to_numpy_matrix__type(
+        self: TestPairwiseLD, mock_pairwise_ld: PairwiseLD
+    ) -> None:
+        """Testing the returned numpy array."""
+        assert isinstance(mock_pairwise_ld.r_to_numpy_matrix(), np.ndarray)
+
+    def test_pairwise_ld__r_to_numpy_matrix__dimensions(
+        self: TestPairwiseLD, mock_pairwise_ld: PairwiseLD
+    ) -> None:
+        """Testing the returned numpy array."""
+        assert mock_pairwise_ld.r_to_numpy_matrix().shape == (
+            len(self.variants),
+            len(self.variants),
+        )
+
+    def test_pairwise_ld__overlap_with_locus(
+        self: TestPairwiseLD, mock_pairwise_ld: PairwiseLD
+    ) -> None:
+        """Testing that subsetting to a locus keeps all pairs of its variants."""
+        variant_subset = self.variants[1:3]
+
+        assert (
+            mock_pairwise_ld.overlap_with_locus(variant_subset).df.count()
+            == len(variant_subset) ** 2
+        )