fix(SummaryStatistics): fix in sanity_filter (#623)
* fix(SummaryStatistics): fix in sanity_filter

* fix: adding prune of inf values

* fix(dataset): removal of inf values from beta and stderr

* fix: fix in test and sanity filter

---------

Co-authored-by: Szymon Szyszkowski <[email protected]>
3 people committed May 30, 2024
1 parent 48cf2a8 commit b60a19f
Showing 4 changed files with 107 additions and 9 deletions.
27 changes: 27 additions & 0 deletions src/gentropy/dataset/dataset.py
@@ -1,10 +1,14 @@
"""Dataset class for gentropy."""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import reduce
from typing import TYPE_CHECKING, Any

import pyspark.sql.functions as f
from pyspark.sql.types import DoubleType
from typing_extensions import Self

from gentropy.common.schemas import flatten_schema
@@ -164,6 +168,29 @@ def validate_schema(self: Dataset) -> None:
f"The following fields present differences in their datatypes: {fields_with_different_observed_datatype}."
)

    def drop_infinity_values(self: Self, *cols: str) -> Self:
        """Drop infinity values from Double typed columns.

        Infinity type reference - https://spark.apache.org/docs/latest/sql-ref-datatypes.html#floating-point-special-values
        The implementation comes from https://stackoverflow.com/questions/34432998/how-to-replace-infinity-in-pyspark-dataframe

        Args:
            *cols (str): Names of the columns to check for infinite values; these should be of DoubleType only!

        Returns:
            Self: Dataset after removing infinite values.
        """
        if len(cols) == 0:
            return self
        inf_strings = ("Inf", "+Inf", "-Inf", "Infinity", "+Infinity", "-Infinity")
        inf_values = [f.lit(v).cast(DoubleType()) for v in inf_strings]
        conditions = [f.col(c).isin(inf_values) for c in cols]
        # reduce the individual filter expressions with OR, e.g.
        # col("beta").isin(inf_values) | col("standardError").isin(inf_values) | ...
        condition = reduce(lambda a, b: a | b, conditions)
        self.df = self._df.filter(~condition)
        return self

    def persist(self: Self) -> Self:
        """Persist in memory the DataFrame included in the Dataset.
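For context, here is a standalone sketch of the two behaviours the new method relies on. This is an editorial illustration, not part of the commit; it assumes a local PySpark session and a toy column named "beta". Spark's na.drop() removes null and NaN rows but keeps +/-infinity, which is an ordinary Double value, while a string literal such as "Infinity" cast to DoubleType parses to the special floating-point value, so the isin() comparison above matches infinite entries.

# Illustrative sketch, not from the commit: why dropna() alone is not enough,
# and why casting "Infinity" strings to DoubleType catches infinite rows.
import pyspark.sql.functions as f
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(float("nan"),), (float("inf"),), (0.5,)], ["beta"])

# na.drop() removes the NaN row but keeps the infinite one.
assert df.na.drop().count() == 2

# The cast string literals parse to +/-infinity, so isin() matches the inf row;
# negating the condition keeps the NaN row and the finite value.
inf_values = [f.lit(v).cast(DoubleType()) for v in ("Infinity", "-Infinity")]
assert df.filter(~f.col("beta").isin(inf_values)).count() == 2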
22 changes: 14 additions & 8 deletions src/gentropy/dataset/summary_statistics.py
@@ -1,4 +1,5 @@
"""Summary satistics dataset."""

from __future__ import annotations

from dataclasses import dataclass
@@ -108,9 +109,12 @@ def sanity_filter(self: SummaryStatistics) -> SummaryStatistics:
"""The function filters the summary statistics by sanity filters.
The function filters the summary statistics by the following filters:
- The p-value should not be eqaul 1.
- The beta and se should not be equal 0.
- The p-value should be less than 1.
- The pValueMantissa should be greater than 0.
- The beta should not be equal 0.
- The p-value, beta and se should not be NaN.
- The se should be positive.
- The beta and se should not be infinite.
Returns:
SummaryStatistics: The filtered summary statistics.
@@ -119,13 +123,15 @@ def sanity_filter(self: SummaryStatistics) -> SummaryStatistics:
        gwas_df = gwas_df.dropna(
            subset=["beta", "standardError", "pValueMantissa", "pValueExponent"]
        )

-        gwas_df = gwas_df.filter((f.col("beta") != 0) & (f.col("standardError") != 0))
+        gwas_df = gwas_df.filter((f.col("beta") != 0) & (f.col("standardError") > 0))
        gwas_df = gwas_df.filter(
-            f.col("pValueMantissa") * 10 ** f.col("pValueExponent") != 1
+            (f.col("pValueMantissa") * 10 ** f.col("pValueExponent") < 1)
+            & (f.col("pValueMantissa") > 0)
        )

-        return SummaryStatistics(
+        cols = ["beta", "standardError"]
+        summary_stats = SummaryStatistics(
            _df=gwas_df,
            _schema=SummaryStatistics.get_schema(),
-        )
+        ).drop_infinity_values(*cols)
+
+        return summary_stats
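As a plain-Python cross-check of the reworked p-value condition (the mantissa/exponent pairs below are hypothetical, not taken from the commit), a row now survives only when the pair encodes a p-value strictly between 0 and 1:

# Mirrors (mantissa * 10**exponent < 1) & (mantissa > 0) from sanity_filter.
cases = [
    (1.0, 0),      # p = 1.0       -> dropped (the old filter only required p != 1)
    (0.0, -8),     # p = 0.0       -> dropped (mantissa must be positive)
    (2.0, 0),      # p = 2.0       -> dropped (malformed, p > 1)
    (3.1324, -8),  # p = 3.1324e-8 -> kept
]
for mantissa, exponent in cases:
    keep = (mantissa * 10**exponent < 1) and (mantissa > 0)
    print(f"mantissa={mantissa}, exponent={exponent}: {'kept' if keep else 'dropped'}")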
23 changes: 22 additions & 1 deletion tests/gentropy/dataset/test_dataset.py
@@ -2,12 +2,18 @@

from __future__ import annotations

import numpy as np
import pyspark.sql.functions as f
import pytest
from gentropy.dataset.dataset import Dataset
from gentropy.dataset.study_index import StudyIndex
from pyspark.sql import SparkSession
-from pyspark.sql.types import IntegerType, StructField, StructType
+from pyspark.sql.types import (
+    DoubleType,
+    IntegerType,
+    StructField,
+    StructType,
+)


class MockDataset(Dataset):
@@ -57,3 +63,18 @@ def test_dataset_filter(mock_study_index: StudyIndex) -> None:
        filtered.df.select("studyType").distinct().toPandas()["studyType"].to_list()[0]
        == expected_filter_value
    ), "Filtering failed."


def test_dataset_drop_infinity_values() -> None:
    """drop_infinity_values method should remove infinite values from a DoubleType column."""
    spark = SparkSession.getActiveSession()
    data = [np.Infinity, -np.Infinity, np.inf, -np.inf, np.Inf, -np.Inf, 5.1]
    rows = [(v,) for v in data]
    schema = StructType([StructField("field", DoubleType())])
    input_df = spark.createDataFrame(rows, schema=schema)
    assert input_df.count() == 7
    # running without specifying *cols results in no filtering
    ds = MockDataset(_df=input_df, _schema=schema)
    assert ds.drop_infinity_values().df.count() == 7
    # otherwise rows with infinite values in the listed columns are dropped
    assert ds.drop_infinity_values("field").df.count() == 1
44 changes: 44 additions & 0 deletions tests/gentropy/method/test_qc_of_sumstats.py
@@ -5,6 +5,8 @@
import numpy as np
import pandas as pd
import pyspark.sql.functions as f
import pytest
from gentropy.common.session import Session
from gentropy.dataset.summary_statistics import SummaryStatistics
from gentropy.method.sumstat_quality_controls import SummaryStatisticsQC
from pyspark.sql.functions import rand, when
@@ -61,3 +63,45 @@ def test_several_studyid(
    )
    QC = QC.toPandas()
    assert QC.shape == (2, 8)


def test_sanity_filter_remove_inf_values(
    session: Session,
) -> None:
    """Sanity filter removes rows with infinite values in the beta field."""
    data = [
        (
            "GCST012234",
            "10_73856419_C_A",
            10,
            73856419,
            np.Infinity,
            1,
            3.1324,
            -650,
            None,
            0.4671,
        ),
        (
            "GCST012234",
            "14_98074714_G_C",
            14,
            98074714,
            6.697,
            2,
            5.4275,
            -2890,
            None,
            0.4671,
        ),
    ]
    input_df = session.spark.createDataFrame(
        data=data, schema=SummaryStatistics.get_schema()
    )
    summary_stats = SummaryStatistics(
        _df=input_df, _schema=SummaryStatistics.get_schema()
    )
    stats_after_filter = summary_stats.sanity_filter().df.collect()
    assert input_df.count() == 2
    assert len(stats_after_filter) == 1
    assert stats_after_filter[0]["beta"] - 6.697 == pytest.approx(0)
