From 2010fb654e5cae241028e9614b79865e871034e1 Mon Sep 17 00:00:00 2001 From: Yakov Date: Tue, 24 Sep 2024 17:10:23 +0100 Subject: [PATCH] fix: remove n_eff check from qc_step (#785) --- src/gentropy/config.py | 1 + .../method/sumstat_quality_controls.py | 21 ++++++++----------- src/gentropy/sumstat_qc_step.py | 4 +++- tests/gentropy/method/test_qc_of_sumstats.py | 16 +++++--------- 4 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/gentropy/config.py b/src/gentropy/config.py index 32edc9a4a..86bfc7afe 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -455,6 +455,7 @@ class GWASQCStep(StepConfig): gwas_path: str = MISSING output_path: str = MISSING studyid: str = MISSING + pval_threshold: float = MISSING _target_: str = "gentropy.sumstat_qc_step.SummaryStatisticsQCStep" diff --git a/src/gentropy/method/sumstat_quality_controls.py b/src/gentropy/method/sumstat_quality_controls.py index 2858f4813..1647851de 100644 --- a/src/gentropy/method/sumstat_quality_controls.py +++ b/src/gentropy/method/sumstat_quality_controls.py @@ -1,4 +1,5 @@ """Summary statistics qulity control methods.""" + from __future__ import annotations import numpy as np @@ -225,13 +226,13 @@ def gc_lambda_check( @staticmethod def number_of_snps( - gwas_for_qc: SummaryStatistics, pval_threhod: float = 5e-8 + gwas_for_qc: SummaryStatistics, pval_threshold: float = 5e-8 ) -> DataFrame: """The function caluates number of SNPs and number of SNPs with p-value less than 5e-8. Args: gwas_for_qc (SummaryStatistics): The instance of the SummaryStatistics class. - pval_threhod (float): The threshold for the p-value. + pval_threshold (float): The threshold for the p-value. Returns: DataFrame: PySpark DataFrame with the number of SNPs and number of SNPs with p-value less than threshold. @@ -243,7 +244,7 @@ def number_of_snps( f.sum( ( f.log10(f.col("pValueMantissa")) + f.col("pValueExponent") - <= np.log10(pval_threhod) + <= np.log10(pval_threshold) ).cast("int") ).alias("n_variants_sig"), ) @@ -254,30 +255,26 @@ def number_of_snps( def get_quality_control_metrics( gwas: SummaryStatistics, limit: int = 100_000_000, - min_count: int = 100_000, - n_total: int = 100_000, + pval_threshold: float = 5e-8, ) -> DataFrame: """The function calculates the quality control metrics for the summary statistics. Args: gwas (SummaryStatistics): The instance of the SummaryStatistics class. limit (int): The limit for the number of variants to be used for the estimation. - min_count (int): The minimum number of variants to be used for the estimation. - n_total (int): The total sample size. + pval_threshold (float): The threshold for the p-value. Returns: DataFrame: PySpark DataFrame with the quality control metrics for the summary statistics. """ qc1 = SummaryStatisticsQC.sumstat_qc_beta_check(gwas_for_qc=gwas) qc2 = SummaryStatisticsQC.sumstat_qc_pz_check(gwas_for_qc=gwas, limit=limit) - qc3 = SummaryStatisticsQC.sumstat_n_eff_check( - gwas_for_qc=gwas, n_total=n_total, limit=limit, min_count=min_count - ) qc4 = SummaryStatisticsQC.gc_lambda_check(gwas_for_qc=gwas, limit=limit) - qc5 = SummaryStatisticsQC.number_of_snps(gwas_for_qc=gwas) + qc5 = SummaryStatisticsQC.number_of_snps( + gwas_for_qc=gwas, pval_threshold=pval_threshold + ) df = ( qc1.join(qc2, on="studyId", how="outer") - .join(qc3, on="studyId", how="outer") .join(qc4, on="studyId", how="outer") .join(qc5, on="studyId", how="outer") ) diff --git a/src/gentropy/sumstat_qc_step.py b/src/gentropy/sumstat_qc_step.py index 0c3b7bb14..b5aed905e 100644 --- a/src/gentropy/sumstat_qc_step.py +++ b/src/gentropy/sumstat_qc_step.py @@ -16,6 +16,7 @@ def __init__( gwas_path: str, output_path: str, studyid: str, + pval_threshold: float = 1e-8, ) -> None: """Calculating quality control metrics on the provided GWAS study. @@ -24,13 +25,14 @@ def __init__( gwas_path (str): Path to the GWAS summary statistics. output_path (str): Output path for the QC results. studyid (str): Study ID for the QC. + pval_threshold (float): P-value threshold for the QC. Default is 1e-8. """ gwas = SummaryStatistics.from_parquet(session, path=gwas_path) ( SummaryStatisticsQC.get_quality_control_metrics( - gwas=gwas, limit=100_000_000, min_count=100, n_total=100000 + gwas=gwas, limit=100_000_000, pval_threshold=pval_threshold ) .write.mode(session.write_mode) .parquet(output_path + "/qc_results_" + studyid) diff --git a/tests/gentropy/method/test_qc_of_sumstats.py b/tests/gentropy/method/test_qc_of_sumstats.py index d734fcaef..8f63e6ba2 100644 --- a/tests/gentropy/method/test_qc_of_sumstats.py +++ b/tests/gentropy/method/test_qc_of_sumstats.py @@ -3,7 +3,6 @@ from __future__ import annotations import numpy as np -import pandas as pd import pyspark.sql.functions as f import pytest from pyspark.sql.functions import rand, when @@ -18,9 +17,7 @@ def test_qc_functions( ) -> None: """Test all sumstat qc functions.""" gwas = sample_summary_statistics.sanity_filter() - QC = SummaryStatisticsQC.get_quality_control_metrics( - gwas=gwas, limit=100000, min_count=100, n_total=100000 - ) + QC = SummaryStatisticsQC.get_quality_control_metrics(gwas=gwas, limit=100000) QC = QC.toPandas() assert QC["n_variants"].iloc[0] == 1663 @@ -29,7 +26,6 @@ def test_qc_functions( assert np.round(QC["mean_beta"].iloc[0], 4) == 0.0013 assert np.round(QC["mean_diff_pz"].iloc[0], 6) == 0 assert np.round(QC["se_diff_pz"].iloc[0], 6) == 0 - assert pd.isna(QC["se_N"].iloc[0]) def test_neff_check_eaf( @@ -41,8 +37,8 @@ def test_neff_check_eaf( gwas_df = gwas_df.withColumn("effectAlleleFrequencyFromSource", f.lit(0.5)) gwas._df = gwas_df - QC = SummaryStatisticsQC.get_quality_control_metrics( - gwas=gwas, limit=100000, min_count=100, n_total=100000 + QC = SummaryStatisticsQC.sumstat_n_eff_check( + gwas_for_qc=gwas, limit=100000, min_count=100, n_total=100000 ) QC = QC.toPandas() assert np.round(QC["se_N"].iloc[0], 4) == 0.5586 @@ -59,11 +55,9 @@ def test_several_studyid( ) gwas._df = gwas_df - QC = SummaryStatisticsQC.get_quality_control_metrics( - gwas=gwas, limit=100000, min_count=100, n_total=100000 - ) + QC = SummaryStatisticsQC.get_quality_control_metrics(gwas=gwas, limit=100000) QC = QC.toPandas() - assert QC.shape == (2, 8) + assert QC.shape == (2, 7) def test_sanity_filter_remove_inf_values(