diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py index 948477600..e5be1a019 100644 --- a/src/gentropy/dataset/l2g_feature_matrix.py +++ b/src/gentropy/dataset/l2g_feature_matrix.py @@ -5,8 +5,6 @@ from functools import reduce from typing import TYPE_CHECKING, Type -from pyspark.sql.functions import col - from gentropy.common.schemas import parse_spark_schema from gentropy.common.spark_helpers import convert_from_long_to_wide from gentropy.dataset.dataset import Dataset @@ -62,22 +60,13 @@ def generate_features( Raises: ValueError: If the feature matrix is empty """ - coloc_methods = ( - colocalisation.df.select("colocalisationMethod") - .distinct() - .toPandas()["colocalisationMethod"] - .tolist() - ) if features_dfs := [ # Extract features ColocalisationFactory._get_max_coloc_per_credible_set( + colocalisation, credible_set, study_index, - colocalisation.filter(col("colocalisationMethod") == method), - method, - ).df - for method in coloc_methods - ] + [ + ).df, StudyLocusFactory._get_tss_distance_features(credible_set, variant_gene).df, StudyLocusFactory._get_vep_features(credible_set, variant_gene).df, ]: diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 4c6d05a81..3b9b5b278 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -26,16 +26,16 @@ class ColocalisationFactory: @staticmethod def _get_max_coloc_per_credible_set( + colocalisation: Colocalisation, credible_set: StudyLocus, studies: StudyIndex, - colocalisation: Colocalisation, ) -> L2GFeature: """Get the maximum colocalisation posterior probability for each pair of overlapping study-locus per type of colocalisation method and QTL type. Args: + colocalisation (Colocalisation): Colocalisation dataset credible_set (StudyLocus): Study locus dataset studies (StudyIndex): Study index dataset - colocalisation (Colocalisation): Colocalisation dataset Returns: L2GFeature: Stores the features with the max coloc probabilities for each pair of study-locus @@ -62,7 +62,9 @@ def _get_max_coloc_per_credible_set( credible_set.df.join( studies.df.select("studyId", "studyType", "geneId"), "studyId" ).selectExpr( - "studyLocusId as rightStudyLocusId", "studyType as right_studyType" + "studyLocusId as rightStudyLocusId", + "studyType as right_studyType", + "geneId", ), on="rightStudyLocusId", how="inner", @@ -135,9 +137,9 @@ def _get_max_coloc_per_credible_set( f.lit("Coloc"), f.col("colocalisationMetric"), f.lit("Maximum"), - f.col("score_type"), + f.regexp_replace(f.col("score_type"), "Local", ""), ).alias("featureName"), - f.col("max_score").alias("featureValue"), + f.col("max_score").cast("float").alias("featureValue"), ) ), _schema=L2GFeature.get_schema(), diff --git a/tests/gentropy/method/test_locus_to_gene.py b/tests/gentropy/method/test_locus_to_gene.py index 3976406c9..898252f9f 100644 --- a/tests/gentropy/method/test_locus_to_gene.py +++ b/tests/gentropy/method/test_locus_to_gene.py @@ -80,30 +80,21 @@ def test_train( class TestColocalisationFactory: """Test the ColocalisationFactory methods.""" - @pytest.mark.parametrize( - "colocalisation_method", - [ - "COLOC", - "eCAVIAR", - ], - ) def test_get_max_coloc_per_credible_set( self: TestColocalisationFactory, mock_study_locus: StudyLocus, mock_study_index: StudyIndex, mock_colocalisation: Colocalisation, - colocalisation_method: str, ) -> None: """Test the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus returns the right data type.""" coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( + mock_colocalisation, mock_study_locus, mock_study_index, - mock_colocalisation, - colocalisation_method, ) assert isinstance( coloc_features, L2GFeature - ), "Unexpected model type returned from _get_max_coloc_per_credible_set" + ), "Unexpected type returned from _get_max_coloc_per_credible_set" def test_get_max_coloc_per_credible_set_semantic( self: TestColocalisationFactory, @@ -169,8 +160,10 @@ def test_get_max_coloc_per_credible_set_semantic( "colocalisationMethod": "eCAVIAR", "numberColocalisingVariants": 1, "clpp": 0.81, # 0.9*0.9 + "log2h4h3": None, } - ] + ], + schema=Colocalisation.get_schema(), ), _schema=Colocalisation.get_schema(), ) @@ -183,10 +176,9 @@ def test_get_max_coloc_per_credible_set_semantic( ) # Test coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( + coloc, credset, studies, - coloc, - "eCAVIAR", ) assert coloc_features.df.collect() == expected_coloc_features_df.collect()