diff --git a/src/gentropy/dataset/dataset.py b/src/gentropy/dataset/dataset.py
index 856357afe..ada3d0580 100644
--- a/src/gentropy/dataset/dataset.py
+++ b/src/gentropy/dataset/dataset.py
@@ -1,10 +1,14 @@
 """Dataset class for gentropy."""
+
 from __future__ import annotations

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from functools import reduce
 from typing import TYPE_CHECKING, Any

+import pyspark.sql.functions as f
+from pyspark.sql.types import DoubleType
 from typing_extensions import Self

 from gentropy.common.schemas import flatten_schema
@@ -164,6 +168,29 @@ def validate_schema(self: Dataset) -> None:
                 f"The following fields present differences in their datatypes: {fields_with_different_observed_datatype}."
             )

+    def drop_infinity_values(self: Self, *cols: str) -> Self:
+        """Drop rows containing infinite values in the given columns.
+
+        Infinity type reference - https://spark.apache.org/docs/latest/sql-ref-datatypes.html#floating-point-special-values
+        The implementation comes from https://stackoverflow.com/questions/34432998/how-to-replace-infinity-in-pyspark-dataframe
+
+        Args:
+            *cols (str): names of the columns to check for infinite values; these should be of DoubleType only!
+
+        Returns:
+            Self: Dataset after removing infinite values
+        """
+        if len(cols) == 0:
+            return self
+        inf_strings = ("Inf", "+Inf", "-Inf", "Infinity", "+Infinity", "-Infinity")
+        inf_values = [f.lit(v).cast(DoubleType()) for v in inf_strings]
+        conditions = [f.col(c).isin(inf_values) for c in cols]
+        # reduce the per-column filter expressions with OR, i.e.
+        # col("beta").isin(inf_values) | col("standardError").isin(inf_values) | ...
+        condition = reduce(lambda a, b: a | b, conditions)
+        self.df = self._df.filter(~condition)
+        return self
+
     def persist(self: Self) -> Self:
         """Persist in memory the DataFrame included in the Dataset.
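Reviewer note (illustrative, not part of the patch): a minimal standalone sketch of the predicate that drop_infinity_values composes, assuming a local SparkSession; the DataFrame and column names below are made up for the example.

    from functools import reduce

    import pyspark.sql.functions as f
    from pyspark.sql import SparkSession
    from pyspark.sql.types import DoubleType

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame(
        [(1.0, 0.2), (float("inf"), 0.5)], ["beta", "standardError"]
    )

    # one isin() check per column, OR-ed together, then negated to keep finite rows
    inf_values = [f.lit(v).cast(DoubleType()) for v in ("Inf", "-Inf")]
    condition = reduce(
        lambda a, b: a | b,
        [f.col(c).isin(inf_values) for c in ("beta", "standardError")],
    )
    df.filter(~condition).show()  # only the (1.0, 0.2) row remains

The string-to-double cast is what produces the special Infinity values described in the Spark documentation linked in the docstring.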
diff --git a/src/gentropy/dataset/summary_statistics.py b/src/gentropy/dataset/summary_statistics.py
index 6244d5879..71aa35c8a 100644
--- a/src/gentropy/dataset/summary_statistics.py
+++ b/src/gentropy/dataset/summary_statistics.py
@@ -1,4 +1,5 @@
 """Summary satistics dataset."""
+
 from __future__ import annotations

 from dataclasses import dataclass
@@ -108,9 +109,12 @@ def sanity_filter(self: SummaryStatistics) -> SummaryStatistics:
         """The function filters the summary statistics by sanity filters.

         The function filters the summary statistics by the following filters:
-            - The p-value should not be eqaul 1.
-            - The beta and se should not be equal 0.
+            - The p-value should be less than 1.
+            - The pValueMantissa should be greater than 0.
+            - The beta should not be equal 0.
             - The p-value, beta and se should not be NaN.
+            - The se should be positive.
+            - The beta and se should not be infinite.

         Returns:
             SummaryStatistics: The filtered summary statistics.
@@ -119,13 +123,15 @@ def sanity_filter(self: SummaryStatistics) -> SummaryStatistics:
         gwas_df = gwas_df.dropna(
             subset=["beta", "standardError", "pValueMantissa", "pValueExponent"]
         )
-
-        gwas_df = gwas_df.filter((f.col("beta") != 0) & (f.col("standardError") != 0))
+        gwas_df = gwas_df.filter((f.col("beta") != 0) & (f.col("standardError") > 0))
         gwas_df = gwas_df.filter(
-            f.col("pValueMantissa") * 10 ** f.col("pValueExponent") != 1
+            (f.col("pValueMantissa") * 10 ** f.col("pValueExponent") < 1)
+            & (f.col("pValueMantissa") > 0)
         )
-
-        return SummaryStatistics(
+        cols = ["beta", "standardError"]
+        summary_stats = SummaryStatistics(
             _df=gwas_df,
             _schema=SummaryStatistics.get_schema(),
-        )
+        ).drop_infinity_values(*cols)
+
+        return summary_stats
diff --git a/tests/gentropy/dataset/test_dataset.py b/tests/gentropy/dataset/test_dataset.py
index 88f44ed5b..bddbb4f6a 100644
--- a/tests/gentropy/dataset/test_dataset.py
+++ b/tests/gentropy/dataset/test_dataset.py
@@ -2,12 +2,18 @@

 from __future__ import annotations

+import numpy as np
 import pyspark.sql.functions as f
 import pytest
 from gentropy.dataset.dataset import Dataset
 from gentropy.dataset.study_index import StudyIndex
 from pyspark.sql import SparkSession
-from pyspark.sql.types import IntegerType, StructField, StructType
+from pyspark.sql.types import (
+    DoubleType,
+    IntegerType,
+    StructField,
+    StructType,
+)


 class MockDataset(Dataset):
@@ -57,3 +63,18 @@ def test_dataset_filter(mock_study_index: StudyIndex) -> None:
         filtered.df.select("studyType").distinct().toPandas()["studyType"].to_list()[0]
         == expected_filter_value
     ), "Filtering failed."
+
+
+def test_dataset_drop_infinity_values() -> None:
+    """drop_infinity_values should remove infinite values from the specified DoubleType column."""
+    spark = SparkSession.getActiveSession()
+    data = [np.Infinity, -np.Infinity, np.inf, -np.inf, np.Inf, -np.Inf, 5.1]
+    rows = [(v,) for v in data]
+    schema = StructType([StructField("field", DoubleType())])
+    input_df = spark.createDataFrame(rows, schema=schema)
+    assert input_df.count() == 7
+    # running without specifying *cols results in no filtering
+    ds = MockDataset(_df=input_df, _schema=schema)
+    assert ds.drop_infinity_values().df.count() == 7
+    # otherwise rows with an infinite value in any listed column are dropped
+    assert ds.drop_infinity_values("field").df.count() == 1
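Reviewer note (illustrative, not part of the patch): the p-value is stored as pValueMantissa * 10 ** pValueExponent, so the rewritten filter keeps rows with p < 1 and a strictly positive mantissa; e.g. (3.1324, -650) passes, while (1.0, 0) and (0.0, 0) are dropped. A minimal sketch, assuming a local SparkSession:

    import pyspark.sql.functions as f
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame(
        [(3.1324, -650), (1.0, 0), (0.0, 0)], ["pValueMantissa", "pValueExponent"]
    )
    # mirror the two new sanity_filter conditions on the p-value
    kept = df.filter(
        (f.col("pValueMantissa") * 10 ** f.col("pValueExponent") < 1)
        & (f.col("pValueMantissa") > 0)
    )
    kept.show()  # only the (3.1324, -650) row survives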
diff --git a/tests/gentropy/method/test_qc_of_sumstats.py b/tests/gentropy/method/test_qc_of_sumstats.py
index 8480fce8d..6c2d23f65 100644
--- a/tests/gentropy/method/test_qc_of_sumstats.py
+++ b/tests/gentropy/method/test_qc_of_sumstats.py
@@ -5,6 +5,8 @@
 import numpy as np
 import pandas as pd
 import pyspark.sql.functions as f
+import pytest
+from gentropy.common.session import Session
 from gentropy.dataset.summary_statistics import SummaryStatistics
 from gentropy.method.sumstat_quality_controls import SummaryStatisticsQC
 from pyspark.sql.functions import rand, when
@@ -61,3 +63,45 @@ def test_several_studyid(
     )
     QC = QC.toPandas()
     assert QC.shape == (2, 8)
+
+
+def test_sanity_filter_remove_inf_values(
+    session: Session,
+) -> None:
+    """Sanity filter removes rows with an infinite beta value."""
+    data = [
+        (
+            "GCST012234",
+            "10_73856419_C_A",
+            10,
+            73856419,
+            np.Infinity,
+            1,
+            3.1324,
+            -650,
+            None,
+            0.4671,
+        ),
+        (
+            "GCST012234",
+            "14_98074714_G_C",
+            14,
+            98074714,
+            6.697,
+            2,
+            5.4275,
+            -2890,
+            None,
+            0.4671,
+        ),
+    ]
+    input_df = session.spark.createDataFrame(
+        data=data, schema=SummaryStatistics.get_schema()
+    )
+    summary_stats = SummaryStatistics(
+        _df=input_df, _schema=SummaryStatistics.get_schema()
+    )
+    stats_after_filter = summary_stats.sanity_filter().df.collect()
+    assert input_df.count() == 2
+    assert len(stats_after_filter) == 1
+    assert stats_after_filter[0]["beta"] == pytest.approx(6.697)
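Reviewer note (illustrative, not part of the patch): the string-cast literals used in drop_infinity_values could equally be written with Python float literals; is_infinite below is a hypothetical helper sketching that equivalent form, not something this patch adds.

    import pyspark.sql.functions as f
    from pyspark.sql import Column

    def is_infinite(col_name: str) -> Column:
        # hypothetical helper: true when the column holds +Infinity or -Infinity
        return f.col(col_name).isin([float("inf"), float("-inf")])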