Skip to content

Commit

Permalink
chore: adapt tests to changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ireneisdoomed committed Mar 20, 2024
1 parent 75bafd5 commit c1f49e5
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 32 deletions.
15 changes: 2 additions & 13 deletions src/gentropy/dataset/l2g_feature_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from functools import reduce
from typing import TYPE_CHECKING, Type

from pyspark.sql.functions import col

from gentropy.common.schemas import parse_spark_schema
from gentropy.common.spark_helpers import convert_from_long_to_wide
from gentropy.dataset.dataset import Dataset
Expand Down Expand Up @@ -62,22 +60,13 @@ def generate_features(
Raises:
ValueError: If the feature matrix is empty
"""
coloc_methods = (
colocalisation.df.select("colocalisationMethod")
.distinct()
.toPandas()["colocalisationMethod"]
.tolist()
)
if features_dfs := [
# Extract features
ColocalisationFactory._get_max_coloc_per_credible_set(
colocalisation,
credible_set,
study_index,
colocalisation.filter(col("colocalisationMethod") == method),
method,
).df
for method in coloc_methods
] + [
).df,
StudyLocusFactory._get_tss_distance_features(credible_set, variant_gene).df,
StudyLocusFactory._get_vep_features(credible_set, variant_gene).df,
]:
Expand Down
12 changes: 7 additions & 5 deletions src/gentropy/method/l2g/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ class ColocalisationFactory:

@staticmethod
def _get_max_coloc_per_credible_set(
colocalisation: Colocalisation,
credible_set: StudyLocus,
studies: StudyIndex,
colocalisation: Colocalisation,
) -> L2GFeature:
"""Get the maximum colocalisation posterior probability for each pair of overlapping study-locus per type of colocalisation method and QTL type.
Args:
colocalisation (Colocalisation): Colocalisation dataset
credible_set (StudyLocus): Study locus dataset
studies (StudyIndex): Study index dataset
colocalisation (Colocalisation): Colocalisation dataset
Returns:
L2GFeature: Stores the features with the max coloc probabilities for each pair of study-locus
Expand All @@ -62,7 +62,9 @@ def _get_max_coloc_per_credible_set(
credible_set.df.join(
studies.df.select("studyId", "studyType", "geneId"), "studyId"
).selectExpr(
"studyLocusId as rightStudyLocusId", "studyType as right_studyType"
"studyLocusId as rightStudyLocusId",
"studyType as right_studyType",
"geneId",
),
on="rightStudyLocusId",
how="inner",
Expand Down Expand Up @@ -135,9 +137,9 @@ def _get_max_coloc_per_credible_set(
f.lit("Coloc"),
f.col("colocalisationMetric"),
f.lit("Maximum"),
f.col("score_type"),
f.regexp_replace(f.col("score_type"), "Local", ""),
).alias("featureName"),
f.col("max_score").alias("featureValue"),
f.col("max_score").cast("float").alias("featureValue"),
)
),
_schema=L2GFeature.get_schema(),
Expand Down
20 changes: 6 additions & 14 deletions tests/gentropy/method/test_locus_to_gene.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,30 +80,21 @@ def test_train(
class TestColocalisationFactory:
"""Test the ColocalisationFactory methods."""

@pytest.mark.parametrize(
"colocalisation_method",
[
"COLOC",
"eCAVIAR",
],
)
def test_get_max_coloc_per_credible_set(
self: TestColocalisationFactory,
mock_study_locus: StudyLocus,
mock_study_index: StudyIndex,
mock_colocalisation: Colocalisation,
colocalisation_method: str,
) -> None:
"""Test the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus returns the right data type."""
coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set(
mock_colocalisation,
mock_study_locus,
mock_study_index,
mock_colocalisation,
colocalisation_method,
)
assert isinstance(
coloc_features, L2GFeature
), "Unexpected model type returned from _get_max_coloc_per_credible_set"
), "Unexpected type returned from _get_max_coloc_per_credible_set"

def test_get_max_coloc_per_credible_set_semantic(
self: TestColocalisationFactory,
Expand Down Expand Up @@ -169,8 +160,10 @@ def test_get_max_coloc_per_credible_set_semantic(
"colocalisationMethod": "eCAVIAR",
"numberColocalisingVariants": 1,
"clpp": 0.81, # 0.9*0.9
"log2h4h3": None,
}
]
],
schema=Colocalisation.get_schema(),
),
_schema=Colocalisation.get_schema(),
)
Expand All @@ -183,10 +176,9 @@ def test_get_max_coloc_per_credible_set_semantic(
)
# Test
coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set(
coloc,
credset,
studies,
coloc,
"eCAVIAR",
)
assert coloc_features.df.collect() == expected_coloc_features_df.collect()

Expand Down

0 comments on commit c1f49e5

Please sign in to comment.