diff --git a/src/gentropy/assets/schemas/colocalisation.json b/src/gentropy/assets/schemas/colocalisation.json index 7ff7453b9..6e1163cfe 100644 --- a/src/gentropy/assets/schemas/colocalisation.json +++ b/src/gentropy/assets/schemas/colocalisation.json @@ -13,6 +13,12 @@ "type": "long", "metadata": {} }, + { + "name": "rightStudyType", + "nullable": false, + "type": "string", + "metadata": {} + }, { "name": "chromosome", "nullable": false, diff --git a/src/gentropy/assets/schemas/study_locus.json b/src/gentropy/assets/schemas/study_locus.json index 11908f687..a8d15aba6 100644 --- a/src/gentropy/assets/schemas/study_locus.json +++ b/src/gentropy/assets/schemas/study_locus.json @@ -6,6 +6,12 @@ "nullable": false, "type": "long" }, + { + "metadata": {}, + "name": "studyType", + "nullable": true, + "type": "string" + }, { "metadata": {}, "name": "variantId", diff --git a/src/gentropy/assets/schemas/study_locus_overlap.json b/src/gentropy/assets/schemas/study_locus_overlap.json index 9a8e123cd..22ba7705e 100644 --- a/src/gentropy/assets/schemas/study_locus_overlap.json +++ b/src/gentropy/assets/schemas/study_locus_overlap.json @@ -12,6 +12,12 @@ "nullable": false, "type": "long" }, + { + "metadata": {}, + "name": "rightStudyType", + "nullable": false, + "type": "string" + }, { "metadata": {}, "name": "chromosome", diff --git a/src/gentropy/colocalisation.py b/src/gentropy/colocalisation.py index 6b370d426..4f8431b98 100644 --- a/src/gentropy/colocalisation.py +++ b/src/gentropy/colocalisation.py @@ -8,7 +8,6 @@ from pyspark.sql.functions import col from gentropy.common.session import Session -from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import CredibleInterval, StudyLocus from gentropy.method.colocalisation import Coloc @@ -23,7 +22,6 @@ def __init__( self, session: Session, credible_set_path: str, - study_index_path: str, coloc_path: str, colocalisation_method: str, ) -> None: @@ -32,7 +30,6 @@ def __init__( Args: session (Session): Session object. credible_set_path (str): Input credible sets path. - study_index_path (str): Input study index path. coloc_path (str): Output Colocalisation path. colocalisation_method (str): Colocalisation method. """ @@ -47,14 +44,11 @@ def __init__( session, credible_set_path, recursiveFileLookup=True ) ) - si = StudyIndex.from_parquet( - session, study_index_path, recursiveFileLookup=True - ) # Transform overlaps = credible_set.filter_credible_set( CredibleInterval.IS95 - ).find_overlaps(si) + ).find_overlaps() colocalisation_results = colocalisation_class.colocalise(overlaps) # type: ignore # Load diff --git a/src/gentropy/dataset/colocalisation.py b/src/gentropy/dataset/colocalisation.py index c0d074ae3..94a4f09dc 100644 --- a/src/gentropy/dataset/colocalisation.py +++ b/src/gentropy/dataset/colocalisation.py @@ -91,7 +91,7 @@ def extract_maximum_coloc_probability_per_region_and_gene( self.append_study_metadata( study_locus, study_index, - metadata_cols=["studyType", "geneId"], + metadata_cols=["geneId"], colocalisation_side="right", ) # it also filters based on method and qtl type diff --git a/src/gentropy/dataset/study_locus.py b/src/gentropy/dataset/study_locus.py index e8363aa4e..2385df984 100644 --- a/src/gentropy/dataset/study_locus.py +++ b/src/gentropy/dataset/study_locus.py @@ -17,6 +17,7 @@ order_array_of_structs_by_field, ) from gentropy.common.utils import get_logsum +from gentropy.config import WindowBasedClumpingStepConfig from gentropy.dataset.dataset import Dataset from gentropy.dataset.study_locus_overlap import StudyLocusOverlap from gentropy.dataset.variant_index import VariantIndex @@ -45,7 +46,8 @@ class StudyLocusQualityCheck(Enum): PALINDROMIC_ALLELE_FLAG (str): Alleles are palindromic - cannot harmonize AMBIGUOUS_STUDY (str): Association with ambiguous study UNRESOLVED_LD (str): Variant not found in LD reference - LD_CLUMPED (str): Explained by a more significant variant in high LD (clumped) + LD_CLUMPED (str): Explained by a more significant variant in high LD + WINDOW_CLUMPED (str): Explained by a more significant variant in the same window NO_POPULATION (str): Study does not have population annotation to resolve LD NOT_QUALIFYING_LD_BLOCK (str): LD block does not contain variants at the required R^2 threshold FAILED_STUDY (str): Flagging study loci if the study has failed QC @@ -65,7 +67,8 @@ class StudyLocusQualityCheck(Enum): PALINDROMIC_ALLELE_FLAG = "Palindrome alleles - cannot harmonize" AMBIGUOUS_STUDY = "Association with ambiguous study" UNRESOLVED_LD = "Variant not found in LD reference" - LD_CLUMPED = "Explained by a more significant variant in high LD (clumped)" + LD_CLUMPED = "Explained by a more significant variant in high LD" + WINDOW_CLUMPED = "Explained by a more significant variant in the same window" NO_POPULATION = "Study does not have population annotation to resolve LD" NOT_QUALIFYING_LD_BLOCK = ( "LD block does not contain variants at the required R^2 threshold" @@ -157,6 +160,24 @@ def validate_study(self: StudyLocus, study_index: StudyIndex) -> StudyLocus: _schema=self.get_schema(), ) + def annotate_study_type(self: StudyLocus, study_index: StudyIndex) -> StudyLocus: + """Gets study type from study index and adds it to study locus. + + Args: + study_index (StudyIndex): Study index to get study type. + + Returns: + StudyLocus: Updated study locus with study type. + """ + return StudyLocus( + _df=( + self.df.drop("studyType").join( + study_index.study_type_lut(), on="studyId", how="left" + ) + ), + _schema=self.get_schema(), + ) + def validate_variant_identifiers( self: StudyLocus, variant_index: VariantIndex ) -> StudyLocus: @@ -394,6 +415,7 @@ def _align_overlapping_tags( f.col("chromosome"), f.col("tagVariantId"), f.col("studyLocusId").alias("rightStudyLocusId"), + f.col("studyType").alias("rightStudyType"), *[f.col(col).alias(f"right_{col}") for col in stats_cols], ).join(peak_overlaps, on=["chromosome", "rightStudyLocusId"], how="inner") @@ -410,6 +432,7 @@ def _align_overlapping_tags( ).select( "leftStudyLocusId", "rightStudyLocusId", + "rightStudyType", "chromosome", "tagVariantId", f.struct( @@ -504,14 +527,11 @@ def get_QC_mappings(cls: type[StudyLocus]) -> dict[str, str]: """ return {member.name: member.value for member in StudyLocusQualityCheck} - def filter_by_study_type( - self: StudyLocus, study_type: str, study_index: StudyIndex - ) -> StudyLocus: + def filter_by_study_type(self: StudyLocus, study_type: str) -> StudyLocus: """Creates a new StudyLocus dataset filtered by study type. Args: study_type (str): Study type to filter for. Can be one of `gwas`, `eqtl`, `pqtl`, `eqtl`. - study_index (StudyIndex): Study index to resolve study types. Returns: StudyLocus: Filtered study-locus dataset. @@ -523,11 +543,7 @@ def filter_by_study_type( raise ValueError( f"Study type {study_type} not supported. Supported types are: gwas, eqtl, pqtl, sqtl." ) - new_df = ( - self.df.join(study_index.study_type_lut(), on="studyId", how="inner") - .filter(f.col("studyType") == study_type) - .drop("studyType") - ) + new_df = self.df.filter(f.col("studyType") == study_type).drop("studyType") return StudyLocus( _df=new_df, _schema=self._schema, @@ -576,7 +592,7 @@ def filter_ld_set(ld_set: Column, r2_threshold: float) -> Column: ) def find_overlaps( - self: StudyLocus, study_index: StudyIndex, intra_study_overlap: bool = False + self: StudyLocus, intra_study_overlap: bool = False ) -> StudyLocusOverlap: """Calculate overlapping study-locus. @@ -584,14 +600,13 @@ def find_overlaps( appearing on the right side. Args: - study_index (StudyIndex): Study index to resolve study types. intra_study_overlap (bool): If True, finds intra-study overlaps for credible set deduplication. Default is False. Returns: StudyLocusOverlap: Pairs of overlapping study-locus with aligned tags. """ loci_to_overlap = ( - self.df.join(study_index.study_type_lut(), on="studyId", how="inner") + self.df.filter(f.col("studyType").isNotNull()) .withColumn("locus", f.explode("locus")) .select( "studyLocusId", @@ -1032,3 +1047,19 @@ def annotate_locus_statistics_boundaries( ) return self + + def window_based_clumping( + self: StudyLocus, + window_size: int = WindowBasedClumpingStepConfig().distance, + ) -> StudyLocus: + """Clump study locus by window size. + + Args: + window_size (int): Window size for clumping. + + Returns: + StudyLocus: Clumped study locus, where clumped associations are flagged. + """ + from gentropy.method.window_based_clumping import WindowBasedClumping + + return WindowBasedClumping.clump(self, window_size) diff --git a/src/gentropy/dataset/study_locus_overlap.py b/src/gentropy/dataset/study_locus_overlap.py index 5f839bd9c..d14a2da96 100644 --- a/src/gentropy/dataset/study_locus_overlap.py +++ b/src/gentropy/dataset/study_locus_overlap.py @@ -10,7 +10,6 @@ if TYPE_CHECKING: from pyspark.sql.types import StructType - from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus @@ -36,18 +35,17 @@ def get_schema(cls: type[StudyLocusOverlap]) -> StructType: @classmethod def from_associations( - cls: type[StudyLocusOverlap], study_locus: StudyLocus, study_index: StudyIndex + cls: type[StudyLocusOverlap], study_locus: StudyLocus ) -> StudyLocusOverlap: """Find the overlapping signals in a particular set of associations (StudyLocus dataset). Args: study_locus (StudyLocus): Study-locus associations to find the overlapping signals - study_index (StudyIndex): Study index to find the overlapping signals Returns: StudyLocusOverlap: Study-locus overlap dataset """ - return study_locus.find_overlaps(study_index) + return study_locus.find_overlaps() def _convert_to_square_matrix(self: StudyLocusOverlap) -> StudyLocusOverlap: """Convert the dataset to a square matrix. @@ -60,6 +58,7 @@ def _convert_to_square_matrix(self: StudyLocusOverlap) -> StudyLocusOverlap: self.df.selectExpr( "leftStudyLocusId as rightStudyLocusId", "rightStudyLocusId as leftStudyLocusId", + "rightStudyType", "tagVariantId", ) ).distinct(), diff --git a/src/gentropy/dataset/summary_statistics.py b/src/gentropy/dataset/summary_statistics.py index d0875fe85..25edbeca7 100644 --- a/src/gentropy/dataset/summary_statistics.py +++ b/src/gentropy/dataset/summary_statistics.py @@ -77,10 +77,11 @@ def window_based_clumping( from gentropy.method.window_based_clumping import WindowBasedClumping return WindowBasedClumping.clump( - self, + # Before clumping, we filter the summary statistics by p-value: + self.pvalue_filter(gwas_significance), distance=distance, - gwas_significance=gwas_significance, - ) + # After applying the clumping, we filter the clumped loci by the flag: + ).valid_rows(["WINDOW_CLUMPED"]) def locus_breaker_clumping( self: SummaryStatistics, diff --git a/src/gentropy/gwas_catalog_ingestion.py b/src/gentropy/gwas_catalog_ingestion.py index 725f1ca4d..5dab5bf16 100644 --- a/src/gentropy/gwas_catalog_ingestion.py +++ b/src/gentropy/gwas_catalog_ingestion.py @@ -3,6 +3,7 @@ from __future__ import annotations from gentropy.common.session import Session +from gentropy.config import WindowBasedClumpingStepConfig from gentropy.dataset.variant_index import VariantIndex from gentropy.datasource.gwas_catalog.associations import ( GWASCatalogCuratedAssociationsParser, @@ -30,6 +31,7 @@ def __init__( gnomad_variant_path: str, catalog_studies_out: str, catalog_associations_out: str, + distance: int = WindowBasedClumpingStepConfig().distance, gwas_catalog_study_curation_file: str | None = None, inclusion_list_path: str | None = None, ) -> None: @@ -44,6 +46,7 @@ def __init__( gnomad_variant_path (str): Path to GnomAD variants. catalog_studies_out (str): Output GWAS catalog studies path. catalog_associations_out (str): Output GWAS catalog associations path. + distance (int): Distance, within which tagging variants are collected around the semi-index. gwas_catalog_study_curation_file (str | None): file of the curation table. Optional. inclusion_list_path (str | None): optional inclusion list (parquet) """ @@ -86,4 +89,9 @@ def __init__( # Load study_index.df.write.mode(session.write_mode).parquet(catalog_studies_out) - study_locus.df.write.mode(session.write_mode).parquet(catalog_associations_out) + + ( + study_locus.window_based_clumping(distance) + .df.write.mode(session.write_mode) + .parquet(catalog_associations_out) + ) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 832023cd8..13dbb881b 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -204,7 +204,7 @@ def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatr ValueError: If write_feature_matrix is set to True but a path is not provided. ValueError: If dependencies to build features are not set. """ - if self.gs_curation and self.interactions and self.v2g and self.studies: + if self.gs_curation and self.interactions and self.v2g: study_locus_overlap = StudyLocus( _df=self.credible_set.df.join( f.broadcast( @@ -225,7 +225,7 @@ def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatr "inner", ), _schema=StudyLocus.get_schema(), - ).find_overlaps(self.studies) + ).find_overlaps() gold_standards = L2GGoldStandard.from_otg_curation( gold_standard_curation=self.gs_curation, diff --git a/src/gentropy/method/colocalisation.py b/src/gentropy/method/colocalisation.py index c3320f931..7a3a0d9c5 100644 --- a/src/gentropy/method/colocalisation.py +++ b/src/gentropy/method/colocalisation.py @@ -79,7 +79,7 @@ def colocalise( f.col("statistics.right_posteriorProbability"), ), ) - .groupBy("leftStudyLocusId", "rightStudyLocusId", "chromosome") + .groupBy("leftStudyLocusId", "rightStudyLocusId", "rightStudyType", "chromosome") .agg( f.count("*").alias("numberColocalisingVariants"), f.sum(f.col("clpp")).alias("clpp"), @@ -168,7 +168,7 @@ def colocalise( f.col("left_logBF") + f.col("right_logBF"), ) # Group by overlapping peak and generating dense vectors of log_BF: - .groupBy("chromosome", "leftStudyLocusId", "rightStudyLocusId") + .groupBy("chromosome", "leftStudyLocusId", "rightStudyLocusId", "rightStudyType") .agg( f.count("*").alias("numberColocalisingVariants"), fml.array_to_vector(f.collect_list(f.col("left_logBF"))).alias( diff --git a/src/gentropy/method/window_based_clumping.py b/src/gentropy/method/window_based_clumping.py index 629fe627e..9ef747abf 100644 --- a/src/gentropy/method/window_based_clumping.py +++ b/src/gentropy/method/window_based_clumping.py @@ -12,7 +12,7 @@ from pyspark.sql.window import Window from gentropy.config import WindowBasedClumpingStepConfig -from gentropy.dataset.study_locus import StudyLocus +from gentropy.dataset.study_locus import StudyLocus, StudyLocusQualityCheck if TYPE_CHECKING: from numpy.typing import NDArray @@ -154,22 +154,38 @@ def _prune_peak(position: NDArray[np.float64], window_size: int) -> DenseVector: @staticmethod def clump( - summary_statistics: SummaryStatistics, + unclumped_associations: SummaryStatistics | StudyLocus, distance: int = WindowBasedClumpingStepConfig().distance, - gwas_significance: float = WindowBasedClumpingStepConfig().gwas_significance, ) -> StudyLocus: - """Clump significant signals from summary statistics based on window. + """Clump single point associations from summary statistics or study locus dataset based on window. Args: - summary_statistics (SummaryStatistics): Summary statistics to be used for clumping. + unclumped_associations (SummaryStatistics | StudyLocus): Input dataset to be used for clumping. Assumes that the input dataset is already filtered for significant variants. distance (int): Distance in base pairs to be used for clumping. Defaults to 500_000. - gwas_significance (float): GWAS significance threshold. Defaults to 5e-8. Returns: - StudyLocus: clumped summary statistics (without locus collection) - - Check WindowBasedClumpingStepConfig object for default values + StudyLocus: clumped associations, where the clumped variants are flagged. """ + # Quality check expression that flags variants that are not considered lead variant: + qc_check = f.col("semiIndices")[f.col("pvRank") - 1] <= 0 + + # The quality control expression will depend on the input dataset, as the column might be already present: + qc_expression = ( + # When the column is already present and the condition is met, the value is appended to the array, otherwise keep as is: + f.when( + qc_check, + f.array_union( + f.col("qualityControls"), + f.array(f.lit(StudyLocusQualityCheck.WINDOW_CLUMPED.value)), + ), + ).otherwise(f.col("qualityControls")) + if "qualityControls" in unclumped_associations.df.columns + # If column is not there yet, initialize it with the flag value, or an empty array: + else f.when( + qc_check, f.array(f.lit(StudyLocusQualityCheck.WINDOW_CLUMPED.value)) + ).otherwise(f.array().cast(t.ArrayType(t.StringType()))) + ) + # Create window for locus clusters # - variants where the distance between subsequent variants is below the defined threshold. # - Variants are sorted by descending significance @@ -179,11 +195,8 @@ def clump( return StudyLocus( _df=( - summary_statistics - # Dropping snps below significance - all subsequent steps are done on significant variants: - .pvalue_filter(gwas_significance) - .df - # Clustering summary variants for efficient windowing (complexity reduction): + unclumped_associations.df + # Clustering variants for efficient windowing (complexity reduction): .withColumn( "cluster_id", WindowBasedClumping._cluster_peaks( @@ -207,7 +220,7 @@ def clump( ), ).otherwise(f.array()), ) - # Get semi indices only ONCE per cluster: + # Collect top loci per cluster: .withColumn( "semiIndices", f.when( @@ -230,9 +243,6 @@ def clump( ), ).otherwise(f.col("semiIndices")), ) - # Keeping semi indices only: - .filter(f.col("semiIndices")[f.col("pvRank") - 1] > 0) - .drop("pvRank", "collectedPositions", "semiIndices", "cluster_id") # Adding study-locus id: .withColumn( "studyLocusId", @@ -241,9 +251,8 @@ def clump( ), ) # Initialize QC column as array of strings: - .withColumn( - "qualityControls", f.array().cast(t.ArrayType(t.StringType())) - ) + .withColumn("qualityControls", qc_expression) + .drop("pvRank", "collectedPositions", "semiIndices", "cluster_id") ), _schema=StudyLocus.get_schema(), ) diff --git a/src/gentropy/study_locus_validation.py b/src/gentropy/study_locus_validation.py index e3d10f3db..4d1c234dc 100644 --- a/src/gentropy/study_locus_validation.py +++ b/src/gentropy/study_locus_validation.py @@ -46,6 +46,7 @@ def __init__( # Add flag for MHC region .qc_MHC_region() .validate_study(study_index) # Flagging studies not in study index + .annotate_study_type(study_index) # Add study type to study locus .qc_redundant_top_hits_from_PICS() # Flagging top hits from studies with PICS summary statistics .validate_unique_study_locus_id() # Flagging duplicated study locus ids ).persist() # we will need this for 2 types of outputs diff --git a/tests/gentropy/dataset/test_colocalisation.py b/tests/gentropy/dataset/test_colocalisation.py index 5371cf42c..8f2766fb4 100644 --- a/tests/gentropy/dataset/test_colocalisation.py +++ b/tests/gentropy/dataset/test_colocalisation.py @@ -100,10 +100,11 @@ def _setup(self: TestAppendStudyMetadata, spark: SparkSession) -> None: ) self.sample_colocalisation = Colocalisation( _df=spark.createDataFrame( - [(1, 2, "X", "COLOC", 1, 0.9)], + [(1, 2, "eqtl", "X", "COLOC", 1, 0.9)], [ "leftStudyLocusId", "rightStudyLocusId", + "rightStudyType", "chromosome", "colocalisationMethod", "numberColocalisingVariants", diff --git a/tests/gentropy/dataset/test_l2g.py b/tests/gentropy/dataset/test_l2g.py index d37ce5a4a..2523b97dd 100644 --- a/tests/gentropy/dataset/test_l2g.py +++ b/tests/gentropy/dataset/test_l2g.py @@ -70,8 +70,8 @@ def test_filter_unique_associations(spark: SparkSession) -> None: ) mock_sl_overlap_df = spark.createDataFrame( - [(1, 2, "variant2"), (1, 4, "variant4")], - "leftStudyLocusId LONG, rightStudyLocusId LONG, tagVariantId STRING", + [(1, 2, "eqtl", "variant2"), (1, 4, "eqtl", "variant4")], + "leftStudyLocusId LONG, rightStudyLocusId LONG, rightStudyType STRING, tagVariantId STRING", ) expected_df = spark.createDataFrame( diff --git a/tests/gentropy/dataset/test_l2g_feature_matrix.py b/tests/gentropy/dataset/test_l2g_feature_matrix.py index 46384239c..09460ee85 100644 --- a/tests/gentropy/dataset/test_l2g_feature_matrix.py +++ b/tests/gentropy/dataset/test_l2g_feature_matrix.py @@ -136,10 +136,11 @@ def _setup(self: TestFromFeaturesList, spark: SparkSession) -> None: ) self.sample_colocalisation = Colocalisation( _df=spark.createDataFrame( - [(1, 2, "X", "COLOC", 1, 0.9)], + [(1, 2, "eqtl", "X", "COLOC", 1, 0.9)], [ "leftStudyLocusId", "rightStudyLocusId", + "rightStudyType", "chromosome", "colocalisationMethod", "numberColocalisingVariants", diff --git a/tests/gentropy/dataset/test_study_locus.py b/tests/gentropy/dataset/test_study_locus.py index c89521b3c..51fc2ed92 100644 --- a/tests/gentropy/dataset/test_study_locus.py +++ b/tests/gentropy/dataset/test_study_locus.py @@ -43,6 +43,7 @@ { "leftStudyLocusId": 1, "rightStudyLocusId": 2, + "rightStudyType": "eqtl", "chromosome": "1", "tagVariantId": "commonTag", "statistics": { @@ -53,6 +54,7 @@ { "leftStudyLocusId": 1, "rightStudyLocusId": 2, + "rightStudyType": "eqtl", "chromosome": "1", "tagVariantId": "nonCommonTag", "statistics": { @@ -79,6 +81,7 @@ def test_find_overlaps_semantic( "studyLocusId": 1, "variantId": "lead1", "studyId": "study1", + "studyType": "gwas", "locus": [ {"variantId": "commonTag", "posteriorProbability": 0.9}, ], @@ -88,6 +91,7 @@ def test_find_overlaps_semantic( "studyLocusId": 2, "variantId": "lead2", "studyId": "study2", + "studyType": "eqtl", "locus": [ {"variantId": "commonTag", "posteriorProbability": 0.6}, {"variantId": "nonCommonTag", "posteriorProbability": 0.6}, @@ -108,6 +112,7 @@ def test_find_overlaps_semantic( "studyLocusId": 1, "variantId": "lead1", "studyId": "study1", + "studyType": "gwas", "locus": [ {"variantId": "var1", "posteriorProbability": 0.9}, ], @@ -117,6 +122,7 @@ def test_find_overlaps_semantic( "studyLocusId": 2, "variantId": "lead2", "studyId": "study2", + "studyType": "eqtl", "locus": None, "chromosome": "1", }, @@ -126,25 +132,6 @@ def test_find_overlaps_semantic( _schema=StudyLocus.get_schema(), ) - studies = StudyIndex( - _df=spark.createDataFrame( - [ - { - "studyId": "study1", - "studyType": "gwas", - "traitFromSource": "trait1", - "projectId": "project1", - }, - { - "studyId": "study2", - "studyType": "eqtl", - "traitFromSource": "trait2", - "projectId": "project2", - }, - ] - ), - _schema=StudyIndex.get_schema(), - ) expected_overlaps_df = spark.createDataFrame( expected, StudyLocusOverlap.get_schema() ) @@ -154,18 +141,14 @@ def test_find_overlaps_semantic( "statistics.right_posteriorProbability", ] assert ( - credset.find_overlaps(studies).df.select(*cols_to_compare).collect() + credset.find_overlaps().df.select(*cols_to_compare).collect() == expected_overlaps_df.select(*cols_to_compare).collect() ), "Overlaps differ from expected." -def test_find_overlaps( - mock_study_locus: StudyLocus, mock_study_index: StudyIndex -) -> None: +def test_find_overlaps(mock_study_locus: StudyLocus) -> None: """Test study locus overlaps.""" - assert isinstance( - mock_study_locus.find_overlaps(mock_study_index), StudyLocusOverlap - ) + assert isinstance(mock_study_locus.find_overlaps(), StudyLocusOverlap) @pytest.mark.parametrize( @@ -184,39 +167,22 @@ def test_filter_by_study_type( "studyLocusId": 1, "variantId": "lead1", "studyId": "study1", + "studyType": "gwas", }, { # from eqtl "studyLocusId": 2, "variantId": "lead2", "studyId": "study2", + "studyType": "eqtl", }, ], StudyLocus.get_schema(), ), _schema=StudyLocus.get_schema(), ) - studies = StudyIndex( - _df=spark.createDataFrame( - [ - { - "studyId": "study1", - "studyType": "gwas", - "traitFromSource": "trait1", - "projectId": "project1", - }, - { - "studyId": "study2", - "studyType": "eqtl", - "traitFromSource": "trait2", - "projectId": "project2", - }, - ] - ), - _schema=StudyIndex.get_schema(), - ) - observed = sl.filter_by_study_type(study_type, studies) + observed = sl.filter_by_study_type(study_type) assert observed.df.count() == expected_sl_count @@ -783,6 +749,74 @@ def test_study_validation_correctness(self: TestStudyLocusValidation) -> None: ) == 1 +class TestStudyLocusWindowClumping: + """Testing window-based clumping on study locus.""" + + TEST_DATASET = [ + ("s1", "c1", 1, -1), + ("s1", "c1", 2, -2), + ("s1", "c1", 3, -3), + ("s2", "c2", 2, -2), + ("s3", "c2", 2, -2), + ] + + TEST_SCHEMA = t.StructType( + [ + t.StructField("studyId", t.StringType(), False), + t.StructField("chromosome", t.StringType(), False), + t.StructField("position", t.IntegerType(), False), + t.StructField("pValueExponent", t.IntegerType(), False), + ] + ) + + @pytest.fixture(autouse=True) + def _setup(self: TestStudyLocusWindowClumping, spark: SparkSession) -> None: + """Setup study locus for testing.""" + self.study_locus = StudyLocus( + _df=( + spark.createDataFrame( + self.TEST_DATASET, schema=self.TEST_SCHEMA + ).withColumns( + { + "studyLocusId": f.monotonically_increasing_id().cast( + t.LongType() + ), + "pValueMantissa": f.lit(1).cast(t.FloatType()), + "variantId": f.concat( + f.lit("v"), + f.monotonically_increasing_id().cast(t.StringType()), + ), + } + ) + ), + _schema=StudyLocus.get_schema(), + ) + + def test_clump_return_type(self: TestStudyLocusWindowClumping) -> None: + """Testing if the clumping returns the right type.""" + assert isinstance(self.study_locus.window_based_clumping(3), StudyLocus) + + def test_clump_no_data_loss(self: TestStudyLocusWindowClumping) -> None: + """Testing if the clumping returns same number of rows.""" + assert ( + self.study_locus.window_based_clumping(3).df.count() + == self.study_locus.df.count() + ) + + def test_correct_flag(self: TestStudyLocusWindowClumping) -> None: + """Testing if the clumping flags are for variants.""" + assert ( + self.study_locus.window_based_clumping(3) + .df.filter( + f.array_contains( + f.col("qualityControls"), + StudyLocusQualityCheck.WINDOW_CLUMPED.value, + ) + ) + .count() + ) == 2 + + def test_build_feature_matrix( mock_study_locus: StudyLocus, mock_colocalisation: Colocalisation, diff --git a/tests/gentropy/dataset/test_study_locus_overlap.py b/tests/gentropy/dataset/test_study_locus_overlap.py index e26b59c30..7e591df30 100644 --- a/tests/gentropy/dataset/test_study_locus_overlap.py +++ b/tests/gentropy/dataset/test_study_locus_overlap.py @@ -19,19 +19,19 @@ def test_convert_to_square_matrix(spark: SparkSession) -> None: mock_sl_overlap = StudyLocusOverlap( _df=spark.createDataFrame( [ - (1, 2, "variant2"), + (1, 2, "eqtl", "variant2"), ], - "leftStudyLocusId LONG, rightStudyLocusId LONG, tagVariantId STRING", + "leftStudyLocusId LONG, rightStudyLocusId LONG, rightStudyType STRING, tagVariantId STRING", ), _schema=StudyLocusOverlap.get_schema(), ) expected_df = spark.createDataFrame( [ - (1, 2, "variant2"), - (2, 1, "variant2"), + (1, 2, "eqtl", "variant2"), + (2, 1, "eqtl", "variant2"), ], - "leftStudyLocusId LONG, rightStudyLocusId LONG, tagVariantId STRING", + "leftStudyLocusId LONG, rightStudyLocusId LONG, rightStudyType STRING, tagVariantId STRING", ) observed_df = mock_sl_overlap._convert_to_square_matrix().df diff --git a/tests/gentropy/dataset/test_study_locus_overlaps.py b/tests/gentropy/dataset/test_study_locus_overlaps.py index bd3415959..745f07ed2 100644 --- a/tests/gentropy/dataset/test_study_locus_overlaps.py +++ b/tests/gentropy/dataset/test_study_locus_overlaps.py @@ -13,8 +13,6 @@ if TYPE_CHECKING: from pyspark.sql import SparkSession - from gentropy.dataset.study_index import StudyIndex - def test_study_locus_overlap_creation( mock_study_locus_overlap: StudyLocusOverlap, @@ -23,11 +21,9 @@ def test_study_locus_overlap_creation( assert isinstance(mock_study_locus_overlap, StudyLocusOverlap) -def test_study_locus_overlap_from_associations( - mock_study_locus: StudyLocus, mock_study_index: StudyIndex -) -> None: +def test_study_locus_overlap_from_associations(mock_study_locus: StudyLocus) -> None: """Test colocalisation creation from mock associations.""" - overlaps = StudyLocusOverlap.from_associations(mock_study_locus, mock_study_index) + overlaps = StudyLocusOverlap.from_associations(mock_study_locus) assert isinstance(overlaps, StudyLocusOverlap) diff --git a/tests/gentropy/method/test_colocalisation_method.py b/tests/gentropy/method/test_colocalisation_method.py index d6798d831..e292784c1 100644 --- a/tests/gentropy/method/test_colocalisation_method.py +++ b/tests/gentropy/method/test_colocalisation_method.py @@ -29,6 +29,7 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: { "leftStudyLocusId": 1, "rightStudyLocusId": 2, + "rightStudyType": "eqtl", "chromosome": "1", "tagVariantId": "snp", "statistics": {"left_logBF": 10.3, "right_logBF": 10.5}, @@ -52,6 +53,7 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: { "leftStudyLocusId": 1, "rightStudyLocusId": 2, + "rightStudyType": "eqtl", "chromosome": "1", "tagVariantId": "snp1", "statistics": {"left_logBF": 10.3, "right_logBF": 10.5}, @@ -59,6 +61,7 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: { "leftStudyLocusId": 1, "rightStudyLocusId": 2, + "rightStudyType": "eqtl", "chromosome": "1", "tagVariantId": "snp2", "statistics": {"left_logBF": 10.3, "right_logBF": 10.5}, @@ -119,6 +122,7 @@ def test_coloc_no_logbf( { "leftStudyLocusId": 1, "rightStudyLocusId": 2, + "rightStudyType": "eqtl", "chromosome": "1", "tagVariantId": "snp", "statistics": { @@ -131,6 +135,7 @@ def test_coloc_no_logbf( [ StructField("leftStudyLocusId", LongType(), False), StructField("rightStudyLocusId", LongType(), False), + StructField("rightStudyType", StringType(), False), StructField("chromosome", StringType(), False), StructField("tagVariantId", StringType(), False), StructField(