Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Implement Likelihood-Ratio test #178

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/plot_minimal_pydeseq2_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@
# should be a *fitted* :class:`DeseqDataSet <pydeseq2.dds.DeseqDataSet>`
# object.

stat_res = DeseqStats(dds, n_cpus=8)
stat_res = DeseqStats(dds, test="LRT", n_cpus=8)

# %%
# It also has a set of optional keyword arguments (see the :doc:`API documentation
Expand Down
100 changes: 93 additions & 7 deletions pydeseq2/ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple

# import anndata as ad
import numpy as np
Expand All @@ -18,6 +19,7 @@

from pydeseq2.dds import DeseqDataSet
from pydeseq2.utils import get_num_processes
from pydeseq2.utils import lrt_test
from pydeseq2.utils import make_MA_plot
from pydeseq2.utils import nbinomGLM
from pydeseq2.utils import wald_test
Expand All @@ -36,6 +38,9 @@ class DeseqStats:
dds : DeseqDataSet
DeseqDataSet for which dispersion and LFCs were already estimated.

test : Literal["Wald", "LRT"]
The statistical test to use. One of ``["Wald", "LRT"]``.

contrast : list or None
A list of three strings, in the following format:
``['variable_of_interest', 'tested_level', 'ref_level']``.
Expand Down Expand Up @@ -146,6 +151,7 @@ class DeseqStats:
def __init__(
self,
dds: DeseqDataSet,
test: Literal["wald", "LRT"] = "wald",
contrast: Optional[List[str]] = None,
alpha: float = 0.05,
cooks_filter: bool = True,
Expand All @@ -166,6 +172,10 @@ def __init__(

self.dds = dds

if test not in ("wald", "LRT"):
raise ValueError(f"Available tests are `wald` and `LRT`. Got: {test}.")
self.test = test

self.alpha = alpha
self.cooks_filter = cooks_filter
self.independent_filter = independent_filter
Expand Down Expand Up @@ -207,6 +217,10 @@ def __init__(
"to False."
)

self.p_values: pd.Series
self.statistics: pd.Series
self.SE: pd.Series

def summary(
self,
**kwargs,
Expand Down Expand Up @@ -249,7 +263,11 @@ def summary(
self.lfc_null = lfc_null
self.alt_hypothesis = alt_hypothesis
rerun_summary = True
self.run_wald_test()

if self.test == "wald":
self.run_wald_test()
else:
self.run_likelihood_ratio_test()

if self.cooks_filter:
# Filter p-values based on Cooks outliers
Expand All @@ -268,18 +286,19 @@ def summary(
self.results_df = pd.DataFrame(index=self.dds.var_names)
self.results_df["baseMean"] = self.base_mean
self.results_df["log2FoldChange"] = self.LFC @ self.contrast_vector / np.log(2)
self.results_df["lfcSE"] = self.SE / np.log(2)
if self.test == "wald":
self.results_df["lfcSE"] = self.SE / np.log(2)
self.results_df["stat"] = self.statistics
self.results_df["pvalue"] = self.p_values
self.results_df["padj"] = self.padj

if self.contrast[1] == self.contrast[2] == "":
# The factor is continuous
print(f"Log2 fold change & Wald test p-value: " f"{self.contrast[0]}")
print(f"Log2 fold change & test p-value: " f"{self.contrast[0]}")
else:
# The factor is categorical
print(
f"Log2 fold change & Wald test p-value: "
f"Log2 fold change & test p-value: "
f"{self.contrast[0]} {self.contrast[1]} vs {self.contrast[2]}"
)
display(self.results_df)
Expand Down Expand Up @@ -345,16 +364,83 @@ def run_wald_test(self) -> None:

pvals, stats, se = zip(*res)

self.p_values: pd.Series = pd.Series(pvals, index=self.dds.var_names)
self.statistics: pd.Series = pd.Series(stats, index=self.dds.var_names)
self.SE: pd.Series = pd.Series(se, index=self.dds.var_names)
self.p_values = pd.Series(pvals, index=self.dds.var_names)
self.statistics = pd.Series(stats, index=self.dds.var_names)
self.SE = pd.Series(se, index=self.dds.var_names)

# Account for possible all_zeroes due to outlier refitting in DESeqDataSet
if self.dds.refit_cooks and self.dds.varm["replaced"].sum() > 0:
self.SE.loc[self.dds.new_all_zeroes_genes] = 0.0
self.statistics.loc[self.dds.new_all_zeroes_genes] = 0.0
self.p_values.loc[self.dds.new_all_zeroes_genes] = 1.0

def run_likelihood_ratio_test(self) -> None:
"""Perform a Likelihood Ratio test.

Get gene-wise p-values for gene over/under-expression.
"""

num_genes = self.dds.n_vars
num_vars = self.design_matrix.shape[1]

# XXX: Raise a warning if LFCs are shrunk.

def reduce(
design_matrix: np.ndarray, ridge_factor: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
indices = np.full(design_matrix.shape[1], True, dtype=bool)
indices[self.contrast_idx] = False
return design_matrix[:, indices], ridge_factor[:, indices][indices]

# Set regularization factors.
if self.prior_LFC_var is not None:
ridge_factor = np.diag(1 / self.prior_LFC_var**2)
else:
ridge_factor = np.diag(np.repeat(1e-6, num_vars))

design_matrix = self.design_matrix.values
LFCs = self.LFC.values

reduced_design_matrix, reduced_ridge_factor = reduce(design_matrix, ridge_factor)
self.dds.obsm["reduced_design_matrix"] = reduced_design_matrix

if not self.quiet:
print("Running LRT tests...", file=sys.stderr)
start = time.time()
with parallel_backend("loky", inner_max_num_threads=1):
res = Parallel(
n_jobs=self.n_processes,
verbose=self.joblib_verbosity,
batch_size=self.batch_size,
)(
delayed(lrt_test)(
counts=self.dds.X[:, i],
design_matrix=design_matrix,
reduced_design_matrix=reduced_design_matrix,
size_factors=self.dds.obsm["size_factors"],
disp=self.dds.varm["dispersions"][i],
lfc=LFCs[i],
min_mu=self.dds.min_mu,
ridge_factor=ridge_factor,
reduced_ridge_factor=reduced_ridge_factor,
beta_tol=self.dds.beta_tol,
)
for i in range(num_genes)
)
end = time.time()
if not self.quiet:
print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr)

pvals, stats = zip(*res)

self.p_values = pd.Series(pvals, index=self.dds.var_names)
self.statistics = pd.Series(stats, index=self.dds.var_names)

# Account for possible all_zeroes due to outlier refitting in DESeqDataSet
if self.dds.refit_cooks and self.dds.varm["replaced"].sum() > 0:
self.statistics.loc[self.dds.new_all_zeroes_genes] = 0.0
self.p_values.loc[self.dds.new_all_zeroes_genes] = 1.0

def lfc_shrink(self, coeff: Optional[str] = None) -> None:
"""LFC shrinkage with an apeGLM prior :cite:p:`DeseqStats-zhu2019heavy`.

Expand Down
92 changes: 92 additions & 0 deletions pydeseq2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from scipy.optimize import minimize # type: ignore
from scipy.special import gammaln # type: ignore
from scipy.special import polygamma # type: ignore
from scipy.stats import chi2 # type: ignore
from scipy.stats import norm # type: ignore
from sklearn.linear_model import LinearRegression # type: ignore

Expand Down Expand Up @@ -979,6 +980,97 @@ def less_abs(lfc_null):
return wald_p_value, wald_statistic, wald_se


def lrt_test(
counts: np.ndarray,
design_matrix: np.ndarray,
reduced_design_matrix: np.ndarray,
size_factors: np.ndarray,
disp: float,
lfc: np.ndarray,
min_mu: float,
ridge_factor: np.ndarray,
reduced_ridge_factor: np.ndarray,
beta_tol: float,
) -> Tuple[float, float]:
"""Run likelihood ratio test for differential expression.

Compute likelihood ratio test statistics and p-values from
dispersion and LFC estimates.

Parameters
----------
counts : ndarray
Raw counts for a given gene.

design_matrix : ndarray
Design matrix.

reduced_design_matrix : ndarray
Reduced design matrix.

size_factors : ndarray
DESeq2 normalization factors.

disp : float
Dispersion estimate.

lfc : ndarray
Log-fold change estimate (in natural log scale).

min_mu : float
Lower bound on estimated means, to ensure numerical stability.
(default: ``0.5``).

ridge_factor : ndarray
Regularization factors.

reduced_ridge_factor : ndarray
Reduced regularization factors.

beta_tol : float
Stopping criterion for IRWLS:
:math:`\vert dev - dev_{old}\vert / \vert dev + 0.1 \vert < \beta_{tol}`.
(default: ``1e-8``).

Returns
-------
lrt_p_value : float
Estimated p-value.

lrt_statistic : float
LRT statistic.
"""

def reg_nb_nll(
beta: np.ndarray, design_matrix: np.ndarray, ridge_factor: np.ndarray
) -> float:
# closure to minimize
mu_ = np.maximum(size_factors * np.exp(design_matrix @ beta), min_mu)
val = nb_nll(counts, mu_, disp) + 0.5 * (ridge_factor @ beta**2).sum()
return -1.0 * val # maximize the likelihood

beta_reduced, *_ = irls_solver(
counts=counts,
size_factors=size_factors,
design_matrix=reduced_design_matrix,
disp=disp,
min_mu=min_mu,
beta_tol=beta_tol,
)

reduced_ll = reg_nb_nll(beta_reduced, reduced_design_matrix, reduced_ridge_factor)
full_ll = reg_nb_nll(lfc, design_matrix, ridge_factor)

lrt_statistic = 2 * (full_ll - reduced_ll)
# df = 1 since contrast_idx is the only variable removed
lrt_p_value = chi2.sf(lrt_statistic, df=1)

print(lrt_p_value)
print(lrt_statistic)

return lrt_p_value, lrt_statistic


def fit_rough_dispersions(
normed_counts: np.ndarray, design_matrix: pd.DataFrame
) -> np.ndarray:
Expand Down
Loading