Skip to content

Commit

Permalink
feat: functionality added to StudyLocus.find_overlaps() for finding w…
Browse files Browse the repository at this point in the history
…ithin-study overlaps (#587)

* feat: functionality added to StudyLocus.find_overlaps() for finding within-study overlaps

* feat: removal of secondary credible sets at the same region as overlaps

* fix: defining the join condition to make the code tidier
  • Loading branch information
Daniel-Considine authored Apr 26, 2024
1 parent 05d21bc commit a88f16c
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 15 deletions.
50 changes: 38 additions & 12 deletions src/gentropy/dataset/study_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,31 +82,52 @@ class StudyLocus(Dataset):
"""

@staticmethod
def _overlapping_peaks(credset_to_overlap: DataFrame) -> DataFrame:
def _overlapping_peaks(
credset_to_overlap: DataFrame, intra_study_overlap: bool = False
) -> DataFrame:
"""Calculate overlapping signals (study-locus) between GWAS-GWAS and GWAS-Molecular trait.
Args:
credset_to_overlap (DataFrame): DataFrame containing at least `studyLocusId`, `studyType`, `chromosome` and `tagVariantId` columns.
intra_study_overlap (bool): When True, finds intra-study overlaps for credible set deduplication. Default is False.
Returns:
DataFrame: containing `leftStudyLocusId`, `rightStudyLocusId` and `chromosome` columns.
"""
# Reduce columns to the minimum to reduce the size of the dataframe
credset_to_overlap = credset_to_overlap.select(
"studyLocusId", "studyType", "chromosome", "tagVariantId"
"studyLocusId",
"studyId",
"studyType",
"chromosome",
"region",
"tagVariantId",
)
# Define join condition - if intra_study_overlap is True, finds overlaps within the same study. Otherwise finds gwas vs everything overlaps for coloc.
join_condition = (
[
f.col("left.studyId") == f.col("right.studyId"),
f.col("left.chromosome") == f.col("right.chromosome"),
f.col("left.tagVariantId") == f.col("right.tagVariantId"),
f.col("left.studyLocusId") > f.col("right.studyLocusId"),
f.col("left.region") != f.col("right.region"),
]
if intra_study_overlap
else [
f.col("left.chromosome") == f.col("right.chromosome"),
f.col("left.tagVariantId") == f.col("right.tagVariantId"),
(f.col("right.studyType") != "gwas")
| (f.col("left.studyLocusId") > f.col("right.studyLocusId")),
f.col("left.studyType") == f.lit("gwas"),
]
)

return (
credset_to_overlap.alias("left")
.filter(f.col("studyType") == "gwas")
# Self join with complex condition. Left it's all gwas and right can be gwas or molecular trait
# Self join with complex condition.
.join(
credset_to_overlap.alias("right"),
on=[
f.col("left.chromosome") == f.col("right.chromosome"),
f.col("left.tagVariantId") == f.col("right.tagVariantId"),
(f.col("right.studyType") != "gwas")
| (f.col("left.studyLocusId") > f.col("right.studyLocusId")),
],
on=join_condition,
how="inner",
)
.select(
Expand Down Expand Up @@ -305,14 +326,17 @@ def filter_credible_set(
)
return self

def find_overlaps(self: StudyLocus, study_index: StudyIndex) -> StudyLocusOverlap:
def find_overlaps(
self: StudyLocus, study_index: StudyIndex, intra_study_overlap: bool = False
) -> StudyLocusOverlap:
"""Calculate overlapping study-locus.
Find overlapping study-locus that share at least one tagging variant. All GWAS-GWAS and all GWAS-Molecular traits are computed with the Molecular traits always
appearing on the right side.
Args:
study_index (StudyIndex): Study index to resolve study types.
intra_study_overlap (bool): If True, finds intra-study overlaps for credible set deduplication. Default is False.
Returns:
StudyLocusOverlap: Pairs of overlapping study-locus with aligned tags.
Expand All @@ -322,8 +346,10 @@ def find_overlaps(self: StudyLocus, study_index: StudyIndex) -> StudyLocusOverla
.withColumn("locus", f.explode("locus"))
.select(
"studyLocusId",
"studyId",
"studyType",
"chromosome",
"region",
f.col("locus.variantId").alias("tagVariantId"),
f.col("locus.logBF").alias("logBF"),
f.col("locus.posteriorProbability").alias("posteriorProbability"),
Expand All @@ -335,7 +361,7 @@ def find_overlaps(self: StudyLocus, study_index: StudyIndex) -> StudyLocusOverla
)

# overlapping study-locus
peak_overlaps = self._overlapping_peaks(loci_to_overlap)
peak_overlaps = self._overlapping_peaks(loci_to_overlap, intra_study_overlap)

# study-locus overlap by aligning overlapping variants
return self._align_overlapping_tags(loci_to_overlap, peak_overlaps)
Expand Down
49 changes: 46 additions & 3 deletions tests/gentropy/dataset/test_study_locus_overlaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,46 +30,89 @@ def test_study_locus_overlap_from_associations(


@pytest.mark.parametrize(
("observed", "expected"),
("observed", "intrastudy", "expected"),
[
(
# observed - input DataFrame representing gwas and nongwas data to find overlapping signals
[
{
"studyLocusId": 1,
"studyId": "A",
"studyType": "gwas",
"chromosome": "1",
"tagVariantId": "A",
},
{
"studyLocusId": 2,
"studyId": "B",
"studyType": "eqtl",
"chromosome": "1",
"tagVariantId": "A",
},
{
"studyLocusId": 3,
"studyId": "C",
"studyType": "gwas",
"chromosome": "1",
"tagVariantId": "B",
},
],
# intrastudy - bool of whether or not to use inter-study or intra-study logic
False,
# expected - output DataFrame with overlapping signals
[
{"leftStudyLocusId": 1, "rightStudyLocusId": 2, "chromosome": "1"},
],
),
(
# observed - input DataFrame representing intra-study data to find overlapping signals in the same study
[
{
"studyLocusId": 1,
"studyId": "A",
"studyType": "gwas",
"chromosome": "1",
"region": "X",
"tagVariantId": "A",
},
{
"studyLocusId": 2,
"studyId": "A",
"studyType": "gwas",
"chromosome": "1",
"region": "Y",
"tagVariantId": "A",
},
{
"studyLocusId": 3,
"studyId": "B",
"studyType": "gwas",
"chromosome": "1",
"region": "X",
"tagVariantId": "A",
},
],
# intrastudy - bool of whether or not to use inter-study or intra-study logic
True,
# expected - output DataFrame with overlapping signals
[{"leftStudyLocusId": 2, "rightStudyLocusId": 1, "chromosome": "1"}],
),
],
)
def test_overlapping_peaks(
spark: SparkSession, observed: list[dict[str, Any]], expected: list[dict[str, Any]]
spark: SparkSession,
observed: list[dict[str, Any]],
intrastudy: bool,
expected: list[dict[str, Any]],
) -> None:
"""Test overlapping signals between GWAS-GWAS and GWAS-Molecular trait to make sure that mQTLs are always on the right."""
mock_schema = t.StructType(
[
t.StructField("studyLocusId", t.LongType()),
t.StructField("studyId", t.StringType()),
t.StructField("studyType", t.StringType()),
t.StructField("chromosome", t.StringType()),
t.StructField("region", t.StringType()),
t.StructField("tagVariantId", t.StringType()),
]
)
Expand All @@ -81,6 +124,6 @@ def test_overlapping_peaks(
]
)
observed_df = spark.createDataFrame(observed, mock_schema)
result_df = StudyLocus._overlapping_peaks(observed_df)
result_df = StudyLocus._overlapping_peaks(observed_df, intrastudy)
expected_df = spark.createDataFrame(expected, expected_schema)
assert result_df.collect() == expected_df.collect()

0 comments on commit a88f16c

Please sign in to comment.