Heat dist #109

Draft: wants to merge 6 commits into base: single_cell_HD
46 changes: 46 additions & 0 deletions tests/test_heat.py
@@ -0,0 +1,46 @@
import pytest
import torch

from torchcfm.diffusion_distance import HeatKernelKNN, torch_knn_from_data

DEVICES = ["cpu"]
if torch.cuda.is_available():
    DEVICES.append("cuda")


def gt_heat_kernel_knn(data, t, k):
    """Exact heat kernel via eigendecomposition of the kNN graph Laplacian."""
    L = torch_knn_from_data(data, k=k, projection=False, proj_dim=10)
    # eigendecomposition of the symmetric normalized Laplacian
    eigvals, eigvecs = torch.linalg.eigh(L)
    # exact heat kernel: exp(-t L) = U diag(exp(-t lambda)) U^T
    heat_kernel = eigvecs @ torch.diag(torch.exp(-t * eigvals)) @ eigvecs.T
    heat_kernel = (heat_kernel + heat_kernel.T) / 2
    heat_kernel[heat_kernel < 0] = 0.0
    return heat_kernel


@pytest.mark.parametrize("t", [0.1, 1.0])
@pytest.mark.parametrize("order", [10, 30, 50])
@pytest.mark.parametrize("k", [10, 20])
@pytest.mark.parametrize("device", DEVICES)
def test_heat_kernel_knn(t, order, k, device):
    # larger diffusion times get a looser tolerance
    tol = 2e-1 if t >= 1.0 else 1e-1
    data = torch.randn(100, 5)
    data = data.to(device)
    # the operator uses the scanpy graph while the ground truth uses the
    # torch kNN graph, hence the loose tolerances below
    heat_op = HeatKernelKNN(k=k, t=t, order=order, graph_type="scanpy")
    heat_kernel = heat_op(data)

    # the heat kernel must be symmetric
    assert torch.allclose(heat_kernel, heat_kernel.T)

    # and non-negative
    assert torch.all(heat_kernel >= 0)

    # the Chebyshev approximation should be close to the exact kernel
    gt_heat_kernel = gt_heat_kernel_knn(data, t=t, k=k)
    assert torch.allclose(heat_kernel, gt_heat_kernel, atol=tol, rtol=tol)


if __name__ == "__main__":
    pytest.main([__file__])
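
Note that the ground truth above assumes `torch_knn_from_data` returns a truly symmetric normalized Laplacian, since `torch.linalg.eigh` only reads one triangle of its input. A minimal property check along those lines (illustrative, not part of the diff): the Laplacian should be symmetric with eigenvalues in [0, 2].

import torch
from torchcfm.diffusion_distance import torch_knn_from_data

data = torch.randn(100, 5)
L = torch_knn_from_data(data, k=10)

assert torch.allclose(L, L.T)  # symmetric
# the symmetric normalized Laplacian has spectrum in [0, 2]
eigvals = torch.linalg.eigvalsh(L)
assert eigvals.min() >= -1e-5 and eigvals.max() <= 2.0 + 1e-5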
38 changes: 38 additions & 0 deletions torchcfm/cheb_approx.py
@@ -0,0 +1,38 @@
import typing as T

import numpy as np
import torch
from scipy.special import ive


def expm_multiply(
    L: torch.Tensor,
    X: torch.Tensor,
    coeff: T.Union[torch.Tensor, np.ndarray],
    eigval: T.Union[torch.Tensor, np.ndarray],
):
    """Approximate exp(-t * L) @ X with a Chebyshev polynomial expansion.

    `coeff` holds the Chebyshev coefficients from `compute_chebychev_coeff_all`,
    and `eigval` is half the largest eigenvalue of L, used to map the spectrum
    of L into [-1, 1].
    """

    def body(carry, c):
        T0, T1, Y = carry
        # Chebyshev recurrence T_{k+1} = 2 B T_k - T_{k-1}, with B = L / eigval - I
        T2 = (2.0 / eigval) * torch.matmul(L, T1) - 2.0 * T1 - T0
        Y = Y + c * T2
        return (T1, T2, Y)

    # first two terms: T_0 = X and T_1 = B X
    T0 = X
    Y = 0.5 * coeff[0] * T0
    T1 = (1.0 / eigval) * torch.matmul(L, X) - T0
    Y = Y + coeff[1] * T1

    state = (T0, T1, Y)
    for c in coeff[2:]:
        state = body(state, c)

    _, _, Y = state
    return Y


@torch.no_grad()
def compute_chebychev_coeff_all(eigval, t, K):
    """Chebyshev coefficients of exp(-t * x) on [0, 2 * eigval], computed with
    the exponentially scaled modified Bessel function of the first kind."""
    eigval = eigval.detach().cpu()
    return 2.0 * ive(torch.arange(0, K + 1, device=eigval.device), -t * eigval)
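
A quick way to sanity-check this module is to compare the Chebyshev approximation against a dense matrix exponential on a small Laplacian. A minimal sketch (illustrative, not part of the diff), using `torch.matrix_exp` as the reference:

import torch
from torchcfm.cheb_approx import compute_chebychev_coeff_all, expm_multiply

# small symmetric PSD matrix standing in for a graph Laplacian
A = torch.rand(20, 20)
W = (A + A.T) / 2
L = torch.diag(W.sum(dim=1)) - W

t = 1.0
max_eigval = torch.linalg.eigvalsh(L).max()
coeff = compute_chebychev_coeff_all(0.5 * max_eigval, t, K=30)
approx = expm_multiply(L, torch.eye(20), coeff, 0.5 * max_eigval)

exact = torch.matrix_exp(-t * L)
print(torch.allclose(approx, exact, atol=1e-3))  # expected: True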
119 changes: 119 additions & 0 deletions torchcfm/diffusion_distance.py
@@ -0,0 +1,119 @@
import torch

from torchcfm.cheb_approx import compute_chebychev_coeff_all, expm_multiply

try:
    import scanpy as sc
except ImportError:
    sc = None  # scanpy is optional; only needed for graph_type="scanpy"

EPS_LOG = 1e-6
EPS_HEAT = 1e-4


def norm_sym_laplacian(A: torch.Tensor):
    """Symmetric normalized Laplacian I - D^{-1/2} A D^{-1/2}."""
    deg = A.sum(dim=1)
    deg_sqrt_inv = torch.diag(1.0 / torch.sqrt(deg + EPS_LOG))
    identity = torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
    return identity - deg_sqrt_inv @ A @ deg_sqrt_inv


def torch_knn_from_data(
    data: torch.Tensor, k: int, projection: bool = False, proj_dim: int = 100
):
    """Normalized Laplacian of a kNN affinity graph built with plain torch ops."""
    if projection:
        _, _, V = torch.pca_lowrank(data, q=proj_dim, center=True)
        data = data @ V
    dist = torch.cdist(data, data)
    _, indices = torch.topk(dist, k, largest=False)
    affinity = torch.zeros(
        data.shape[0], data.shape[0], device=data.device, dtype=data.dtype
    )
    affinity.scatter_(1, indices, 1.0)
    # symmetrize the directed kNN graph so the Laplacian is truly symmetric
    affinity = torch.maximum(affinity, affinity.T)
    return norm_sym_laplacian(affinity)


def scanpy_knn_from_data(
    data: torch.Tensor, k: int, projection: bool = False, proj_dim: int = 100
):
    """Normalized Laplacian of the kNN connectivity graph computed by scanpy."""
    adata = sc.AnnData(data.detach().cpu().numpy())
    if projection:
        sc.pp.pca(adata, n_comps=proj_dim)
    sc.pp.neighbors(adata, n_neighbors=k, use_rep="X_pca" if projection else None)
    return norm_sym_laplacian(
        torch.tensor(adata.obsp["connectivities"].toarray(), device=data.device)
    )


def var_fn(x, t):
    """Varadhan-style heat distance -t * log(H_t), with a volume correction."""
    outer = torch.outer(torch.diag(x), torch.ones(x.shape[0], device=x.device))
    vol_approx = (outer + outer.T) * 0.5
    return -t * torch.log(x + EPS_LOG) + t * torch.log(vol_approx + EPS_LOG)


class BaseHeatKernel:
    def __init__(self, t: float = 1.0, order: int = 30):
        self.t = t
        self.order = order
        self.dist_fn = {"var": var_fn}
        self.graph_fn = None

    def __call__(self, data: torch.Tensor):
        if self.graph_fn is None:
            raise NotImplementedError("graph_fn is not implemented")
        L = self.graph_fn(data)
        heat_kernel = self.compute_heat_from_laplacian(L)
        heat_kernel = self.sym_clip(heat_kernel)
        return heat_kernel

    def compute_heat_from_laplacian(self, L: torch.Tensor):
        n = L.shape[0]
        val = torch.linalg.eigvals(L).real
        max_eigval = val.max()
        # the Chebyshev expansion is centered at half the largest eigenvalue
        cheb_coeff = compute_chebychev_coeff_all(0.5 * max_eigval, self.t, self.order)
        heat_kernel = expm_multiply(
            L, torch.eye(n, device=L.device, dtype=L.dtype), cheb_coeff, 0.5 * max_eigval
        )
        return heat_kernel

    def sym_clip(self, heat_kernel: torch.Tensor):
        # enforce symmetry and floor negative entries at a small positive value
        heat_kernel = (heat_kernel + heat_kernel.T) / 2
        heat_kernel[heat_kernel < 0] = EPS_HEAT
        return heat_kernel

    def fit(self, data: torch.Tensor, dist_type: str = "var"):
        assert dist_type in self.dist_fn, f"dist_type must be one of {list(self.dist_fn)}"
        heat_kernel = self(data)
        return self.dist_fn[dist_type](heat_kernel, self.t)


class HeatKernelKNN(BaseHeatKernel):
    """Approximation of the heat kernel on a graph built from a k-nearest-neighbors
    affinity matrix, using a Chebyshev polynomial expansion of exp(-t * L).
    """

    _is_differentiable = False
    _implemented_graph = {
        "torch": torch_knn_from_data,
        "scanpy": scanpy_knn_from_data,
    }

    def __init__(
        self,
        k: int = 10,
        order: int = 30,
        t: float = 1.0,
        projection: bool = False,
        proj_dim: int = 100,
        graph_type: str = "scanpy",
    ):
        super().__init__(t=t, order=order)
        assert (
            graph_type in self._implemented_graph
        ), f"graph_type must be one of {list(self._implemented_graph)}"
        self.k = k
        self.projection = projection
        self.proj_dim = proj_dim
        self.graph_fn = lambda x: self._implemented_graph[graph_type](
            x, self.k, projection=self.projection, proj_dim=self.proj_dim
        )
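
For review context, a minimal end-to-end usage sketch (not part of the diff), assuming the `torch` graph backend so it runs without scanpy installed:

import torch
from torchcfm.diffusion_distance import HeatKernelKNN

data = torch.randn(200, 5)

# heat kernel at diffusion time t on a kNN graph, Chebyshev order 30
heat_op = HeatKernelKNN(k=10, t=1.0, order=30, graph_type="torch")
heat_kernel = heat_op(data)  # (200, 200), symmetric, non-negative

# Varadhan-style heat distances via fit() and the "var" distance function
dist = heat_op.fit(data, dist_type="var")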