diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index ff378e170..5373be108 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -191,37 +191,43 @@ class StudyLocusFactory(StudyLocus): """Feature extraction in study locus.""" @staticmethod - def _get_tss_distance_features( - credible_set: StudyLocus, distances: V2G - ) -> L2GFeature: - """Joins StudyLocus with the V2G to extract the minimum distance to a gene TSS of all variants in a StudyLocus credible set. + def _get_tss_distance_features(credible_set: StudyLocus, v2g: V2G) -> L2GFeature: + """Joins StudyLocus with the V2G to extract a score that is based on the distance to a gene TSS of any variant weighted by its posterior probability in a credible set. Args: credible_set (StudyLocus): Credible set dataset - distances (V2G): Dataframe containing the distances of all variants to all genes TSS within a region + v2g (V2G): Dataframe containing the distances of all variants to all genes TSS within a region Returns: - L2GFeature: Stores the features with the minimum distance among all variants in the credible set and a gene TSS. + L2GFeature: Stores the features with the score of weighting the distance to the TSS by the posterior probability of the variant """ wide_df = ( credible_set.filter_credible_set(CredibleInterval.IS95) - .df.select( + .df.withColumn("variantInLocus", f.explode_outer("locus")) + .select( "studyLocusId", "variantId", - f.explode("locus.variantId").alias("tagVariantId"), + f.col("variantInLocus.variantId").alias("variantInLocusId"), + f.col("variantInLocus.posteriorProbability").alias( + "variantInLocusPosteriorProbability" + ), ) .join( - distances.df.selectExpr( - "variantId as tagVariantId", "geneId", "distance" + v2g.df.filter(f.col("datasourceId") == "canonical_tss").selectExpr( + "variantId as variantInLocusId", "geneId", "score" ), - on="tagVariantId", + on="variantInLocusId", how="inner", ) + .withColumn( + "weightedScore", + f.col("score") * f.col("variantInLocusPosteriorProbability"), + ) .groupBy("studyLocusId", "geneId") .agg( - f.min("distance").alias("distanceTssMinimum"), - f.mean("distance").alias("distanceTssMean"), + f.min("weightedScore").alias("distanceTssMinimum"), + f.mean("weightedScore").alias("distanceTssMean"), ) )