diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index b43125017..4c6d05a81 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -7,7 +7,6 @@ import pyspark.sql.functions as f from gentropy.common.spark_helpers import ( - convert_from_long_to_wide, convert_from_wide_to_long, get_record_with_maximum_value, ) @@ -121,12 +120,15 @@ def _get_max_coloc_per_credible_set( ) ).drop("tmp_nbh_max_score", "local_max_score") - wide_df = convert_from_long_to_wide( - df=( + return L2GFeature( + _df=( + # Combine local and neighborhood metrics local_max.unionByName( neighbourhood_max, allowMissingColumns=True - ).withColumn( - "featureName", + ).select( + "studyLocusId", + "geneId", + # Feature name is a concatenation of the QTL type, colocalisation metric and if it's local or in the vicinity f.concat_ws( "", f.col("right_studyType"), @@ -134,31 +136,10 @@ def _get_max_coloc_per_credible_set( f.col("colocalisationMetric"), f.lit("Maximum"), f.col("score_type"), - ), + ).alias("featureName"), + f.col("max_score").alias("featureValue"), ) ), - id_vars=["studyLocusId", "geneId"], - var_name="featureName", - value_name="max_score", - ) - - return L2GFeature( - _df=convert_from_wide_to_long( - wide_df.groupBy("studyLocusId", "geneId").agg( - *( - f.first(f.col(c), ignorenulls=True).alias(c) - for c in wide_df.columns - if c - not in [ - "studyLocusId", - "geneId", - ] - ) - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ).filter(f.col("featureValue").isNotNull()), _schema=L2GFeature.get_schema(), )