Add kw args to kernel computations
gmeanti committed Sep 20, 2023
1 parent 98e7271 commit eab3abb
Showing 8 changed files with 624 additions and 139 deletions.
8 changes: 4 additions & 4 deletions falkon/kernels/diff_kernel.py
@@ -87,10 +87,10 @@ def nondiff_params(self) -> Dict[str, Any]:
         """
         return self._other_params
 
-    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool):
-        return self.core_fn(X1, X2, out, **self.diff_params, diag=diag, **self._other_params)
+    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool, **kwargs):
+        return self.core_fn(X1, X2, out=out, diag=diag, **kwargs, **self.diff_params, **self._other_params)
 
-    def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool):
+    def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool, **kwargs):
         """
         Compute the kernel matrix of ``X1`` and ``X2``. The output should be differentiable with
         respect to `X1`, `X2`, and all kernel parameters returned by the :meth:`diff_params` method.
@@ -110,7 +110,7 @@ def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool):
         out : torch.Tensor
             The constructed kernel matrix.
         """
-        return self.core_fn(X1, X2, out=None, diag=diag, **self.diff_params, **self._other_params)
+        return self.core_fn(X1, X2, out=None, diag=diag, **kwargs, **self.diff_params, **self._other_params)
 
     @abc.abstractmethod
     def detach(self) -> "Kernel":
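Since compute and compute_diff now forward **kwargs straight into self.core_fn, a core function only has to tolerate extra keyword arguments to remain compatible. A minimal sketch of such a core function, assuming a Gaussian-style kernel with a single sigma (my_rbf_core is hypothetical, not part of falkon):

import torch

# Hypothetical core_fn matching the contract implied by the new signatures:
# input tensors, an optional preallocated `out`, a `diag` flag, kernel
# parameters, and a **kwargs catch-all for forwarded per-call metadata.
def my_rbf_core(X1, X2, out, diag, sigma, **kwargs):
    gamma = 0.5 / sigma**2
    if diag:
        k = torch.exp(-gamma * ((X1 - X2) ** 2).sum(-1))  # only k(x_i, y_i)
    else:
        k = torch.exp(-gamma * torch.cdist(X1, X2) ** 2)  # full kernel matrix
    if out is None:
        return k  # differentiable path, as used by compute_diff()
    out.copy_(k)  # preallocated-output path, as used by compute()
    return out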
65 changes: 34 additions & 31 deletions falkon/kernels/distance_kernel.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Type, Union
+from typing import Dict, Optional, Type, Union, Tuple
 
 import numpy as np
 import torch
@@ -137,22 +137,12 @@ def _rbf_diag_core(mat1, mat2, out: Optional[torch.Tensor], sigma: torch.Tensor)
     return out
 
 
-def rbf_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma: torch.Tensor) -> torch.Tensor:
+def rbf_core(
+    mat1: torch.Tensor, mat2: torch.Tensor, out: Optional[torch.Tensor], diag: bool, sigma: torch.Tensor
+) -> torch.Tensor:
     """
     Note 1: if out is None, then this function will be differentiable wrt all three remaining inputs.
     Note 2: this function can deal with batched inputs
-
-    Parameters
-    ----------
-    mat1
-    mat2
-    out
-    diag
-    sigma
-
-    Returns
-    -------
-
     """
     # Move hparams
     sigma = sigma.to(device=mat1.device, dtype=mat1.dtype)
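As the surviving docstring notes, rbf_core also handles batched inputs. A short plain-torch illustration of the same computation on a batch (independent of falkon's code paths):

import torch

# Batched RBF illustration: exp(-||x - y||^2 / (2 sigma^2)) via torch.cdist,
# which broadcasts over leading batch dimensions.
sigma = torch.tensor(1.5)
X1, X2 = torch.randn(2, 5, 3), torch.randn(2, 7, 3)  # batch of 2
K = torch.exp(-0.5 * torch.cdist(X1 / sigma, X2 / sigma) ** 2)  # shape (2, 5, 7)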
@@ -172,11 +162,11 @@ def rbf_core_sparse(
 def rbf_core_sparse(
     mat1: SparseTensor,
     mat2: SparseTensor,
+    out: torch.Tensor,
     mat1_csr: SparseTensor,
     mat2_csr: SparseTensor,
-    out: torch.Tensor,
     diag: bool,
-    sigma,
+    sigma: torch.Tensor,
 ) -> torch.Tensor:
     if diag:
         return _distancek_diag(mat1, out)
@@ -189,7 +179,9 @@ def rbf_core_sparse(
     return out
 
 
-def laplacian_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma):
+def laplacian_core(
+    mat1: torch.Tensor, mat2: torch.Tensor, out: Optional[torch.Tensor], diag: bool, sigma: torch.Tensor
+):
     if diag:
         return _distancek_diag(mat1, out)
     # Move hparams
@@ -215,11 +207,11 @@ def laplacian_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma):
 def laplacian_core_sparse(
     mat1: SparseTensor,
     mat2: SparseTensor,
+    out: torch.Tensor,
     mat1_csr: SparseTensor,
     mat2_csr: SparseTensor,
-    out: torch.Tensor,
     diag: bool,
-    sigma,
+    sigma: torch.Tensor,
 ) -> torch.Tensor:
     if diag:
         return _distancek_diag(mat1, out)
@@ -233,7 +225,9 @@ def laplacian_core_sparse(
     return out
 
 
-def matern_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma, nu):
+def matern_core(
+    mat1: torch.Tensor, mat2: torch.Tensor, out: Optional[torch.Tensor], diag: bool, sigma: torch.Tensor, nu: float
+):
     if diag:
         return _distancek_diag(mat1, out)
     # Move hparams
@@ -280,21 +274,21 @@ def matern_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma, nu):
 def matern_core_sparse(
     mat1: SparseTensor,
     mat2: SparseTensor,
+    out: torch.Tensor,
     mat1_csr: SparseTensor,
     mat2_csr: SparseTensor,
-    out: torch.Tensor,
     diag: bool,
-    sigma,
-    nu,
+    sigma: torch.Tensor,
+    nu: float,
 ) -> torch.Tensor:
     if diag:
         return _distancek_diag(mat1, out)
     # Move hparams
     sigma = sigma.to(device=mat1.device, dtype=mat1.dtype)
     if nu == 0.5:
-        return laplacian_core_sparse(mat1, mat2, mat1_csr, mat2_csr, out, diag, sigma)
+        return laplacian_core_sparse(mat1, mat2, out, mat1_csr, mat2_csr, diag, sigma)
     elif nu == float("inf"):
-        return rbf_core_sparse(mat1, mat2, mat1_csr, mat2_csr, out, diag, sigma)
+        return rbf_core_sparse(mat1, mat2, out, mat1_csr, mat2_csr, diag, sigma)
     gamma = 1 / (sigma**2)
     out = _sparse_sq_dist(X1_csr=mat1_csr, X2_csr=mat2_csr, X1=mat1, X2=mat2, out=out)
     out.mul_(gamma)
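The dispatch in matern_core_sparse relies on the standard Matern limit cases: nu = 1/2 recovers the Laplacian kernel and nu -> inf the Gaussian. A dense plain-torch illustration of those limit forms (not falkon code):

import torch

# Matern limit cases on sigma-scaled distances d = ||x/sigma - y/sigma||:
#   nu = 1/2  ->  exp(-d)                          (Laplacian)
#   nu = 3/2  ->  (1 + sqrt(3) d) exp(-sqrt(3) d)
#   nu = inf  ->  exp(-d^2 / 2)                    (Gaussian)
sigma = torch.tensor(2.0)
x, y = torch.randn(5, 3), torch.randn(7, 3)
d = torch.cdist(x / sigma, y / sigma)

k_laplacian = torch.exp(-d)
k_matern32 = (1 + 3**0.5 * d) * torch.exp(-(3**0.5) * d)
k_gaussian = torch.exp(-0.5 * d**2)
print(k_laplacian.shape, k_matern32.shape, k_gaussian.shape)  # all (5, 7)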
@@ -390,7 +384,7 @@ def __init__(self, sigma: Union[float, torch.Tensor], opt: Optional[FalkonOption
         sigma = validate_sigma(sigma)
         super().__init__(self.kernel_name, opt, core_fn=GaussianKernel.core_fn, sigma=sigma)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         formula = "Exp(SqDist(x1 / g, x2 / g) * IntInv(-2)) * v"
         aliases = [
             "x1 = Vi(%d)" % (X1.shape[1]),
@@ -422,12 +416,15 @@ def compute_sparse(
         X2: SparseTensor,
         out: torch.Tensor,
         diag: bool,
+        n_ids: Tuple[int, int],
+        m_ids: Tuple[int, int],
         X1_csr: SparseTensor,
         X2_csr: SparseTensor,
+        **kwargs,
     ) -> torch.Tensor:
         if len(self.sigma) > 1:
             raise NotImplementedError("Sparse kernel is only implemented for scalar sigmas.")
-        return rbf_core_sparse(X1, X2, X1_csr, X2_csr, out, diag, self.sigma)
+        return rbf_core_sparse(X1, X2, out, X1_csr, X2_csr, diag, self.sigma)
 
     def __repr__(self):
         return f"GaussianKernel(sigma={self.sigma})"
@@ -463,7 +460,7 @@ def __init__(self, sigma: Union[float, torch.Tensor], opt: Optional[FalkonOption
 
         super().__init__(self.kernel_name, opt, core_fn=laplacian_core, sigma=sigma)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         formula = "Exp(-Sqrt(SqDist(x1 / g, x2 / g))) * v"
         aliases = [
             "x1 = Vi(%d)" % (X1.shape[1]),
@@ -495,12 +492,15 @@ def compute_sparse(
         X2: SparseTensor,
         out: torch.Tensor,
         diag: bool,
+        n_ids: Tuple[int, int],
+        m_ids: Tuple[int, int],
         X1_csr: SparseTensor,
         X2_csr: SparseTensor,
+        **kwargs,
     ) -> torch.Tensor:
         if len(self.sigma) > 1:
             raise NotImplementedError("Sparse kernel is only implemented for scalar sigmas.")
-        return laplacian_core_sparse(X1, X2, X1_csr, X2_csr, out, diag, self.sigma)
+        return laplacian_core_sparse(X1, X2, out, X1_csr, X2_csr, diag, self.sigma)
 
     def __repr__(self):
         return f"LaplacianKernel(sigma={self.sigma})"
@@ -546,7 +546,7 @@ def __init__(
         self.kernel_name = f"{nu:.1f}-matern"
         super().__init__(self.kernel_name, opt, core_fn=matern_core, sigma=sigma, nu=nu)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         if self.nu == 0.5:
             formula = "Exp(-Norm2(x1 / s - x2 / s)) * v"
         elif self.nu == 1.5:
@@ -616,12 +616,15 @@ def compute_sparse(
         X2: SparseTensor,
         out: torch.Tensor,
         diag: bool,
+        n_ids: Tuple[int, int],
+        m_ids: Tuple[int, int],
         X1_csr: SparseTensor,
         X2_csr: SparseTensor,
+        **kwargs,
     ) -> torch.Tensor:
         if len(self.sigma) > 1:
             raise NotImplementedError("Sparse kernel is only implemented for scalar sigmas.")
-        return matern_core_sparse(X1, X2, X1_csr, X2_csr, out, diag, self.sigma, self.nondiff_params["nu"])
+        return matern_core_sparse(X1, X2, out, X1_csr, X2_csr, diag, self.sigma, self.nondiff_params["nu"])
 
     def __repr__(self):
         return f"MaternKernel(sigma={self.sigma}, nu={self.nondiff_params['nu']:.1f})"
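The formula and alias strings assembled in keops_mmv_impl are ultimately handed to KeOps. A hedged sketch of the Gaussian kernel-vector product using pykeops.torch.Genred directly (assumes pykeops is installed; the alias dimensions and the final call are illustrative, and falkon's own keops helper plus the new kwargs_m1/kwargs_m2 plumbing are not shown):

import torch
from pykeops.torch import Genred

# Standalone sketch of the reduction GaussianKernel.keops_mmv_impl encodes:
# out[i] = sum_j exp(-||x1_i - x2_j||^2 / (2 sigma^2)) * v[j]
d = 3
formula = "Exp(SqDist(x1 / g, x2 / g) * IntInv(-2)) * v"
aliases = [
    f"x1 = Vi({d})",  # rows of X1, indexed over i
    f"x2 = Vj({d})",  # rows of X2, indexed over j
    "v = Vj(1)",      # multiplicand entries, indexed over j
    f"g = Pm({d})",   # sigma, passed as a parameter vector
]
mmv = Genred(formula, aliases, reduction_op="Sum", axis=1)  # reduce over j

X1, X2 = torch.randn(100, d), torch.randn(50, d)
v, sigma = torch.randn(50, 1), torch.full((d,), 1.5)
out = mmv(X1, X2, v, sigma)  # shape (100, 1): K(X1, X2) @ v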
21 changes: 8 additions & 13 deletions falkon/kernels/dot_prod_kernel.py
@@ -174,7 +174,7 @@ def __init__(
         gamma = validate_diff_float(gamma, param_name="gamma")
         super().__init__("Linear", opt, linear_core, beta=beta, gamma=gamma)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         formula = "(gamma * (X | Y) + beta) * v"
         aliases = [
             "X = Vi(%d)" % (X1.shape[1]),
@@ -242,7 +242,7 @@ def __init__(
         degree = validate_diff_float(degree, param_name="degree")
         super().__init__("Polynomial", opt, polynomial_core, beta=beta, gamma=gamma, degree=degree)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         formula = "Powf((gamma * (X | Y) + beta), degree) * v"
         aliases = [
             "X = Vi(%d)" % (X1.shape[1]),
@@ -271,15 +271,7 @@ def detach(self) -> "PolynomialKernel":
     def compute_sparse(
         self, X1: SparseTensor, X2: SparseTensor, out: torch.Tensor, diag: bool, **kwargs
     ) -> torch.Tensor:
-        return polynomial_core_sparse(
-            X1,
-            X2,
-            out,
-            diag,
-            beta=self.beta,
-            gamma=self.gamma,
-            degree=self.degree,
-        )
+        return polynomial_core_sparse(X1, X2, out, diag, beta=self.beta, gamma=self.gamma, degree=self.degree)
 
     def __str__(self):
         return f"PolynomialKernel(beta={self.beta}, gamma={self.gamma}, degree={self.degree})"
@@ -310,13 +302,16 @@ class SigmoidKernel(DiffKernel, KeopsKernelMixin):
     """
 
     def __init__(
-        self, beta: Union[float, torch.Tensor], gamma: Union[float, torch.Tensor], opt: Optional[FalkonOptions] = None
+        self,
+        beta: Union[float, torch.Tensor],
+        gamma: Union[float, torch.Tensor],
+        opt: Optional[FalkonOptions] = None,
     ):
         beta = validate_diff_float(beta, param_name="beta")
         gamma = validate_diff_float(gamma, param_name="gamma")
         super().__init__("Sigmoid", opt, sigmoid_core, beta=beta, gamma=gamma)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         return RuntimeError("SigmoidKernel is not implemented in KeOps")
 
     def _decide_mmv_impl(self, X1, X2, v, opt):
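For orientation, the dense forms of the three dot-product kernels whose KeOps formulas appear above, in plain torch (illustrative only; the tanh form for SigmoidKernel is the textbook definition, which this diff does not restate):

import torch

# Dense dot-product kernels (illustration, not falkon's core functions):
beta, gamma, degree = 1.0, 0.5, 3.0
X, Y = torch.randn(4, 6), torch.randn(8, 6)

k_linear = gamma * (X @ Y.T) + beta               # Linear:     gamma * (X | Y) + beta
k_poly = (gamma * (X @ Y.T) + beta) ** degree     # Polynomial: Powf(gamma * (X | Y) + beta, degree)
k_sigmoid = torch.tanh(gamma * (X @ Y.T) + beta)  # Sigmoid: no KeOps implementation
print(k_linear.shape, k_poly.shape, k_sigmoid.shape)  # all (4, 8)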