diff --git a/src/gentropy/datasource/gnomad/ld.py b/src/gentropy/datasource/gnomad/ld.py index 4d4007d0f..1d2d6c18f 100644 --- a/src/gentropy/datasource/gnomad/ld.py +++ b/src/gentropy/datasource/gnomad/ld.py @@ -7,12 +7,10 @@ from typing import TYPE_CHECKING import hail as hl -import numpy as np import pyspark.sql.functions as f from hail.linalg import BlockMatrix from pyspark.sql import Window -from gentropy.common.session import Session from gentropy.common.spark_helpers import get_top_ranked_in_window, get_value_from_row from gentropy.common.utils import _liftover_loci, convert_gnomad_position_to_ensembl from gentropy.dataset.ld_index import LDIndex @@ -220,9 +218,9 @@ def _resolve_variant_indices( DataFrame: Dataframe with variant IDs instead of `i` and `j` indices """ ld_index_i = ld_index.selectExpr( - "idx as i", "variantId as variantIdI", "chromosome" + "idx as i", "variantId as variantId_i", "chromosome" ) - ld_index_j = ld_index.selectExpr("idx as j", "variantId as variantIdJ") + ld_index_j = ld_index.selectExpr("idx as j", "variantId as variantId_j") return ( ld_matrix.join(ld_index_i, on="i", how="inner") .join(ld_index_j, on="j", how="inner") @@ -240,35 +238,35 @@ def _transpose_ld_matrix(ld_matrix: DataFrame) -> DataFrame: DataFrame: Square LD matrix without diagonal duplicates Examples: - >>> df = spark.createDataFrame( - ... [ - ... (1, 1, 1.0, "1", "AFR"), - ... (1, 2, 0.5, "1", "AFR"), - ... (2, 2, 1.0, "1", "AFR"), - ... ], - ... ["variantIdI", "variantIdJ", "r", "chromosome", "population"], - ... ) - >>> GnomADLDMatrix._transpose_ld_matrix(df).show() - +----------+----------+---+----------+----------+ - |variantIdI|variantIdJ| r|chromosome|population| - +----------+----------+---+----------+----------+ - | 1| 2|0.5| 1| AFR| - | 1| 1|1.0| 1| AFR| - | 2| 1|0.5| 1| AFR| - | 2| 2|1.0| 1| AFR| - +----------+----------+---+----------+----------+ - + >>> df = spark.createDataFrame( + ... [ + ... (1, 1, 1.0, "1", "AFR"), + ... (1, 2, 0.5, "1", "AFR"), + ... (2, 2, 1.0, "1", "AFR"), + ... ], + ... ["variantId_i", "variantId_j", "r", "chromosome", "population"], + ... ) + >>> GnomADLDMatrix._transpose_ld_matrix(df).show() + +-----------+-----------+---+----------+----------+ + |variantId_i|variantId_j| r|chromosome|population| + +-----------+-----------+---+----------+----------+ + | 1| 2|0.5| 1| AFR| + | 1| 1|1.0| 1| AFR| + | 2| 1|0.5| 1| AFR| + | 2| 2|1.0| 1| AFR| + +-----------+-----------+---+----------+----------+ + """ ld_matrix_transposed = ld_matrix.selectExpr( - "variantIdI as variantIdJ", - "variantIdJ as variantIdI", + "variantId_i as variantId_j", + "variantId_j as variantId_i", "r", "chromosome", "population", ) - return ld_matrix.filter(f.col("variantIdI") != f.col("variantIdJ")).unionByName( - ld_matrix_transposed - ) + return ld_matrix.filter( + f.col("variantId_i") != f.col("variantId_j") + ).unionByName(ld_matrix_transposed) def as_ld_index( self: GnomADLDMatrix, @@ -309,8 +307,8 @@ def as_ld_index( GnomADLDMatrix._transpose_ld_matrix( reduce(lambda df1, df2: df1.unionByName(df2), ld_indices_unaggregated) ) - .withColumnRenamed("variantIdI", "variantId") - .withColumnRenamed("variantIdJ", "tagVariantId") + .withColumnRenamed("variantId_i", "variantId") + .withColumnRenamed("variantId_j", "tagVariantId") ) return LDIndex( _df=self._aggregate_ld_index_across_populations(ld_index_unaggregated), @@ -347,6 +345,7 @@ def get_ld_variants( & (f.col("position") <= end) ) .select("chromosome", "position", "variantId", "idx") + .persist() ) if ld_index_df.limit(1).count() == 0: @@ -396,7 +395,7 @@ def _extract_square_matrix( .join( ld_index_df.select( f.col("idx").alias("idx_i"), - f.col("variantId").alias("variantIdI"), + f.col("variantId").alias("variantId_i"), ), on="idx_i", how="inner", @@ -404,12 +403,12 @@ def _extract_square_matrix( .join( ld_index_df.select( f.col("idx").alias("idx_j"), - f.col("variantId").alias("variantIdJ"), + f.col("variantId").alias("variantId_j"), ), on="idx_j", how="inner", ) - .select("variantIdI", "variantIdJ", "r") + .select("variantId_i", "variantId_j", "r") ) def get_ld_matrix_slice( @@ -449,78 +448,3 @@ def get_ld_matrix_slice( .alias("r"), ) ) - - @staticmethod - def get_locus_index( - session: Session, - study_locus_row: DataFrame, - window_size: int, - ld_index_path: str, - major_population: str = "nfe", - ) -> DataFrame: - """Extract hail matrix index from StudyLocus rows. - - Args: - session (Session): Spark session - study_locus_row (DataFrame): Study-locus row - window_size (int): Window size to extract from gnomad matrix - ld_index_path (str): Optional path to the LD index parquet - major_population (str): Major population to extract from gnomad matrix, default is "nfe" - Returns: - DataFrame: Returns the index of the gnomad matrix for the locus - """ - _df = ( - study_locus_row.withColumn("start", f.col("position") - (window_size / 2)) - .withColumn("end", f.col("position") + (window_size / 2)) - .alias("_df") - ) - - _matrix_index = session.spark.read.parquet( - ld_index_path.format(POP=major_population) - ) - - _index_joined = ( - _df.alias("df") - .join( - _matrix_index.alias("matrix_index"), - (f.col("df.chromosome") == f.col("matrix_index.chromosome")) - & (f.col("df.start") <= f.col("matrix_index.position")) - & (f.col("df.end") >= f.col("matrix_index.position")), - ) - .select( - "matrix_index.chromosome", - "matrix_index.position", - "referenceAllele", - "alternateAllele", - "idx", - ) - .sort("idx") - ) - - return _index_joined - - @staticmethod - def get_locus_matrix( - locus_index: DataFrame, - gnomad_ancestry: str, - ) -> np.ndarray: - """Extract the LD block matrix for a locus. - - Args: - locus_index (DataFrame): hail matrix variant index table - gnomad_ancestry (str): GnomAD major ancestry label eg. `nfe` - - Returns: - np.ndarray: LD block matrix for the locus - """ - idx = [row["idx"] for row in locus_index.select("idx").collect()] - - half_matrix = ( - BlockMatrix.read( - GnomADLDMatrix.ld_matrix_template.format(POP=gnomad_ancestry) - ) - .filter(idx, idx) - .to_numpy() - ) - - return (half_matrix + half_matrix.T) - np.diag(np.diag(half_matrix))