diff --git a/src/gentropy/datasource/gnomad/ld.py b/src/gentropy/datasource/gnomad/ld.py index 99b7b7ad2..752a05abb 100644 --- a/src/gentropy/datasource/gnomad/ld.py +++ b/src/gentropy/datasource/gnomad/ld.py @@ -12,13 +12,12 @@ from hail.linalg import BlockMatrix from pyspark.sql import Window -from gentropy.common.session import Session from gentropy.common.spark_helpers import get_top_ranked_in_window, get_value_from_row from gentropy.common.utils import _liftover_loci, convert_gnomad_position_to_ensembl from gentropy.dataset.ld_index import LDIndex if TYPE_CHECKING: - from pyspark.sql import DataFrame + from pyspark.sql import DataFrame, Row @dataclass @@ -36,6 +35,7 @@ class GnomADLDMatrix: ld_matrix_template: str = "gs://gcp-public-data--gnomad/release/2.1.1/ld/gnomad.genomes.r2.1.1.{POP}.common.adj.ld.bm" ld_index_raw_template: str = "gs://gcp-public-data--gnomad/release/2.1.1/ld/gnomad.genomes.r2.1.1.{POP}.common.ld.variant_indices.ht" + liftover_ht_path: str = "gs://gcp-public-data--gnomad/release/2.1.1/liftover_grch38/ht/genomes/gnomad.genomes.r2.1.1.sites.liftover_grch38.ht" grch37_to_grch38_chain_path: str = ( "gs://hail-common/references/grch37_to_grch38.over.chain.gz" ) @@ -450,20 +450,16 @@ def get_ld_matrix_slice( ) ) - @staticmethod def get_locus_index( - session: Session, - study_locus_row: DataFrame, - ld_index_path: str, + self: GnomADLDMatrix, + study_locus_row: Row, window_size: int = 1_000_000, major_population: str = "nfe", ) -> DataFrame: """Extract hail matrix index from StudyLocus rows. Args: - session (Session): Spark session - study_locus_row (DataFrame): Study-locus row - ld_index_path (str): Path to the hail LD index parquet + study_locus_row (Row): Study-locus row window_size (int): Window size to extract from gnomad matrix major_population (str): Major population to extract from gnomad matrix, default is "nfe" @@ -471,40 +467,35 @@ def get_locus_index( DataFrame: Returns the index of the gnomad matrix for the locus """ - _df = ( - study_locus_row.withColumn("start", f.col("position") - (window_size / 2)) - .withColumn("end", f.col("position") + (window_size / 2)) - .alias("_df") + chromosome = str("chr" + study_locus_row["chromosome"]) + start = study_locus_row["position"] - window_size // 2 + end = study_locus_row["position"] + window_size // 2 + + liftover_ht = hl.read_table(self.liftover_ht_path) + liftover_ht = ( + liftover_ht.filter( + (liftover_ht.locus.contig == chromosome) + & (liftover_ht.locus.position >= start) + & (liftover_ht.locus.position <= end) + ) + .key_by() + .select("locus", "alleles", "original_locus") + .key_by("original_locus", "alleles") + .naive_coalesce(20) ) - _matrix_index = session.spark.read.parquet( - ld_index_path.format(POP=major_population) + hail_index = hl.read_table( + self.ld_index_raw_template.format(POP=major_population) ) - _index_joined = ( - _df.alias("df") - .join( - _matrix_index.alias("matrix_index"), - (f.col("df.chromosome") == f.col("matrix_index.chromosome")) - & (f.col("df.start") <= f.col("matrix_index.position")) - & (f.col("df.end") >= f.col("matrix_index.position")), - ) - .select( - "matrix_index.chromosome", - "matrix_index.position", - "referenceAllele", - "alternateAllele", - "idx", - ) - .sort("idx") - ) + joined_index = liftover_ht.join(hail_index, how="inner").to_spark().sort("idx") - return _index_joined + return joined_index @staticmethod - def get_locus_matrix( + def get_numpy_matrix( locus_index: DataFrame, - gnomad_ancestry: str, + gnomad_ancestry: str = "nfe", ) -> np.ndarray: """Extract the LD block matrix for a locus.