diff --git a/falkon/kernels/diff_kernel.py b/falkon/kernels/diff_kernel.py
index dfe93f4..69991cd 100644
--- a/falkon/kernels/diff_kernel.py
+++ b/falkon/kernels/diff_kernel.py
@@ -87,10 +87,10 @@ def nondiff_params(self) -> Dict[str, Any]:
         """
         return self._other_params
 
-    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool):
-        return self.core_fn(X1, X2, out, **self.diff_params, diag=diag, **self._other_params)
+    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool, **kwargs):
+        return self.core_fn(X1, X2, out=out, diag=diag, **kwargs, **self.diff_params, **self._other_params)
 
-    def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool):
+    def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool, **kwargs):
         """
         Compute the kernel matrix of ``X1`` and ``X2``. The output should be differentiable with
         respect to `X1`, `X2`, and all kernel parameters returned by the :meth:`diff_params` method.
@@ -110,7 +110,7 @@ def compute_diff(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool):
         out : torch.Tensor
             The constructed kernel matrix.
         """
-        return self.core_fn(X1, X2, out=None, diag=diag, **self.diff_params, **self._other_params)
+        return self.core_fn(X1, X2, out=None, diag=diag, **kwargs, **self.diff_params, **self._other_params)
 
     @abc.abstractmethod
     def detach(self) -> "Kernel":
diff --git a/falkon/kernels/distance_kernel.py b/falkon/kernels/distance_kernel.py
index 46eb898..f3940e2 100644
--- a/falkon/kernels/distance_kernel.py
+++ b/falkon/kernels/distance_kernel.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Type, Union
+from typing import Dict, Optional, Type, Union, Tuple
 
 import numpy as np
 import torch
@@ -137,22 +137,12 @@ def _rbf_diag_core(mat1, mat2, out: Optional[torch.Tensor], sigma: torch.Tensor)
     return out
 
 
-def rbf_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma: torch.Tensor) -> torch.Tensor:
+def rbf_core(
+    mat1: torch.Tensor, mat2: torch.Tensor, out: Optional[torch.Tensor], diag: bool, sigma: torch.Tensor
+) -> torch.Tensor:
     """
     Note 1: if out is None, then this function will be differentiable wrt all three remaining inputs.
     Note 2: this function can deal with batched inputs
-
-    Parameters
-    ----------
-    mat1
-    mat2
-    out
-    diag
-    sigma
-
-    Returns
-    -------
-
     """
     # Move hparams
     sigma = sigma.to(device=mat1.device, dtype=mat1.dtype)
@@ -172,11 +162,11 @@ def rbf_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma: torch.T
 def rbf_core_sparse(
     mat1: SparseTensor,
     mat2: SparseTensor,
+    out: torch.Tensor,
     mat1_csr: SparseTensor,
     mat2_csr: SparseTensor,
-    out: torch.Tensor,
     diag: bool,
-    sigma,
+    sigma: torch.Tensor,
 ) -> torch.Tensor:
     if diag:
         return _distancek_diag(mat1, out)
@@ -189,7 +179,9 @@ def rbf_core_sparse(
     return out
 
 
-def laplacian_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma):
+def laplacian_core(
+    mat1: torch.Tensor, mat2: torch.Tensor, out: Optional[torch.Tensor], diag: bool, sigma: torch.Tensor
+):
     if diag:
         return _distancek_diag(mat1, out)
     # Move hparams
@@ -215,11 +207,11 @@ def laplacian_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma):
 def laplacian_core_sparse(
     mat1: SparseTensor,
     mat2: SparseTensor,
+    out: torch.Tensor,
     mat1_csr: SparseTensor,
     mat2_csr: SparseTensor,
-    out: torch.Tensor,
     diag: bool,
-    sigma,
+    sigma: torch.Tensor,
 ) -> torch.Tensor:
     if diag:
         return _distancek_diag(mat1, out)
@@ -233,7 +225,9 @@ def laplacian_core_sparse(
     return out
 
 
-def matern_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma, nu):
+def matern_core(
+    mat1: torch.Tensor, mat2: torch.Tensor, out: Optional[torch.Tensor], diag: bool, sigma: torch.Tensor, nu: float
+):
     if diag:
         return _distancek_diag(mat1, out)
     # Move hparams
@@ -280,21 +274,21 @@ def matern_core(mat1, mat2, out: Optional[torch.Tensor], diag: bool, sigma, nu):
 def matern_core_sparse(
     mat1: SparseTensor,
     mat2: SparseTensor,
+    out: torch.Tensor,
     mat1_csr: SparseTensor,
     mat2_csr: SparseTensor,
-    out: torch.Tensor,
     diag: bool,
-    sigma,
-    nu,
+    sigma: torch.Tensor,
+    nu: float,
 ) -> torch.Tensor:
     if diag:
         return _distancek_diag(mat1, out)
     # Move hparams
     sigma = sigma.to(device=mat1.device, dtype=mat1.dtype)
     if nu == 0.5:
-        return laplacian_core_sparse(mat1, mat2, mat1_csr, mat2_csr, out, diag, sigma)
+        return laplacian_core_sparse(mat1, mat2, out, mat1_csr, mat2_csr, diag, sigma)
     elif nu == float("inf"):
-        return rbf_core_sparse(mat1, mat2, mat1_csr, mat2_csr, out, diag, sigma)
+        return rbf_core_sparse(mat1, mat2, out, mat1_csr, mat2_csr, diag, sigma)
     gamma = 1 / (sigma**2)
     out = _sparse_sq_dist(X1_csr=mat1_csr, X2_csr=mat2_csr, X1=mat1, X2=mat2, out=out)
     out.mul_(gamma)
@@ -390,7 +384,7 @@ def __init__(self, sigma: Union[float, torch.Tensor], opt: Optional[FalkonOption
         sigma = validate_sigma(sigma)
         super().__init__(self.kernel_name, opt, core_fn=GaussianKernel.core_fn, sigma=sigma)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         formula = "Exp(SqDist(x1 / g, x2 / g) * IntInv(-2)) * v"
         aliases = [
             "x1 = Vi(%d)" % (X1.shape[1]),
@@ -422,12 +416,15 @@ def compute_sparse(
         X2: SparseTensor,
         out: torch.Tensor,
         diag: bool,
+        n_ids: Tuple[int, int],
+        m_ids: Tuple[int, int],
         X1_csr: SparseTensor,
         X2_csr: SparseTensor,
+        **kwargs,
     ) -> torch.Tensor:
         if len(self.sigma) > 1:
             raise NotImplementedError("Sparse kernel is only implemented for scalar sigmas.")
-        return rbf_core_sparse(X1, X2, X1_csr, X2_csr, out, diag, self.sigma)
+        return rbf_core_sparse(X1, X2, out, X1_csr, X2_csr, diag, self.sigma)
 
     def __repr__(self):
         return f"GaussianKernel(sigma={self.sigma})"
@@ -463,7 +460,7 @@ def __init__(self, sigma: Union[float, torch.Tensor], opt: Optional[FalkonOption
         super().__init__(self.kernel_name, opt, core_fn=laplacian_core, sigma=sigma)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         formula = "Exp(-Sqrt(SqDist(x1 / g, x2 / g))) * v"
         aliases = [
             "x1 = Vi(%d)" % (X1.shape[1]),
@@ -495,12 +492,15 @@ def compute_sparse(
         X2: SparseTensor,
         out: torch.Tensor,
         diag: bool,
+        n_ids: Tuple[int, int],
+        m_ids: Tuple[int, int],
         X1_csr: SparseTensor,
         X2_csr: SparseTensor,
+        **kwargs,
     ) -> torch.Tensor:
         if len(self.sigma) > 1:
             raise NotImplementedError("Sparse kernel is only implemented for scalar sigmas.")
-        return laplacian_core_sparse(X1, X2, X1_csr, X2_csr, out, diag, self.sigma)
+        return laplacian_core_sparse(X1, X2, out, X1_csr, X2_csr, diag, self.sigma)
 
     def __repr__(self):
         return f"LaplacianKernel(sigma={self.sigma})"
@@ -546,7 +546,7 @@ def __init__(
         self.kernel_name = f"{nu:.1f}-matern"
         super().__init__(self.kernel_name, opt, core_fn=matern_core, sigma=sigma, nu=nu)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         if self.nu == 0.5:
             formula = "Exp(-Norm2(x1 / s - x2 / s)) * v"
         elif self.nu == 1.5:
@@ -616,12 +616,15 @@ def compute_sparse(
         X2: SparseTensor,
         out: torch.Tensor,
         diag: bool,
+        n_ids: Tuple[int, int],
+        m_ids: Tuple[int, int],
         X1_csr: SparseTensor,
         X2_csr: SparseTensor,
+        **kwargs,
     ) -> torch.Tensor:
         if len(self.sigma) > 1:
             raise NotImplementedError("Sparse kernel is only implemented for scalar sigmas.")
-        return matern_core_sparse(X1, X2, X1_csr, X2_csr, out, diag, self.sigma, self.nondiff_params["nu"])
+        return matern_core_sparse(X1, X2, out, X1_csr, X2_csr, diag, self.sigma, self.nondiff_params["nu"])
 
     def __repr__(self):
         return f"MaternKernel(sigma={self.sigma}, nu={self.nondiff_params['nu']:.1f})"
diff --git a/falkon/kernels/dot_prod_kernel.py b/falkon/kernels/dot_prod_kernel.py
index d18e01c..e7f4334 100644
--- a/falkon/kernels/dot_prod_kernel.py
+++ b/falkon/kernels/dot_prod_kernel.py
@@ -174,7 +174,7 @@ def __init__(
         gamma = validate_diff_float(gamma, param_name="gamma")
         super().__init__("Linear", opt, linear_core, beta=beta, gamma=gamma)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         formula = "(gamma * (X | Y) + beta) * v"
         aliases = [
             "X = Vi(%d)" % (X1.shape[1]),
@@ -242,7 +242,7 @@ def __init__(
         degree = validate_diff_float(degree, param_name="degree")
         super().__init__("Polynomial", opt, polynomial_core, beta=beta, gamma=gamma, degree=degree)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         formula = "Powf((gamma * (X | Y) + beta), degree) * v"
         aliases = [
             "X = Vi(%d)" % (X1.shape[1]),
@@ -271,15 +271,7 @@ def detach(self) -> "PolynomialKernel":
     def compute_sparse(
         self, X1: SparseTensor, X2: SparseTensor, out: torch.Tensor, diag: bool, **kwargs
     ) -> torch.Tensor:
-        return polynomial_core_sparse(
-            X1,
-            X2,
-            out,
-            diag,
-            beta=self.beta,
-            gamma=self.gamma,
-            degree=self.degree,
-        )
+        return polynomial_core_sparse(X1, X2, out, diag, beta=self.beta, gamma=self.gamma, degree=self.degree)
 
     def __str__(self):
         return f"PolynomialKernel(beta={self.beta}, gamma={self.gamma}, degree={self.degree})"
@@ -310,13 +302,16 @@ class SigmoidKernel(DiffKernel, KeopsKernelMixin):
     """
 
     def __init__(
-        self, beta: Union[float, torch.Tensor], gamma: Union[float, torch.Tensor], opt: Optional[FalkonOptions] = None
+        self,
+        beta: Union[float, torch.Tensor],
+        gamma: Union[float, torch.Tensor],
+        opt: Optional[FalkonOptions] = None,
     ):
         beta = validate_diff_float(beta, param_name="beta")
         gamma = validate_diff_float(gamma, param_name="gamma")
         super().__init__("Sigmoid", opt, sigmoid_core, beta=beta, gamma=gamma)
 
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
+    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2):
         return RuntimeError("SigmoidKernel is not implemented in KeOps")
 
     def _decide_mmv_impl(self, X1, X2, v, opt):
diff --git a/falkon/kernels/keops_helpers.py b/falkon/kernels/keops_helpers.py
index a4470f0..30a22e6 100644
--- a/falkon/kernels/keops_helpers.py
+++ b/falkon/kernels/keops_helpers.py
@@ -1,6 +1,6 @@
 import abc
 import functools
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Dict
 
 import torch
 
@@ -9,13 +9,6 @@ from falkon.sparse import SparseTensor
 from falkon.utils.switches import decide_keops
 
-try:
-    from falkon.mmv_ops.keops import run_keops_mmv
-
-    _has_keops = True
-except ModuleNotFoundError:
-    _has_keops = False
-
 __all__ = (
     "should_use_keops",
     "KeopsKernelMixin",
@@ -23,7 +16,9 @@
 def should_use_keops(
-    T1: Union[torch.Tensor, SparseTensor], T2: Union[torch.Tensor, SparseTensor], opt: KeopsOptions
+    T1: Union[torch.Tensor, SparseTensor],
+    T2: Union[torch.Tensor, SparseTensor],
+    opt: KeopsOptions,
 ) -> bool:
     """Check whether the conditions to use KeOps for mmv operations are satisfied
 
@@ -113,10 +108,13 @@ def keops_mmv(
         out
             The computed kernel matrix between ``X1`` and ``X2``, multiplied by vector ``v``.
         """
-        if not _has_keops:
+        try:
+            from falkon.mmv_ops.keops import run_keops_mmv
+        except ModuleNotFoundError as e:
             raise ModuleNotFoundError(
                 "Module 'pykeops' is not properly installed. Please install 'pykeops' before running 'keops_mmv'."
-            )
+            ) from e
+
         if other_vars is None:
             other_vars = []
         return run_keops_mmv(
@@ -132,7 +130,20 @@
             opt=opt,
         )
 
-    def keops_dmmv_helper(self, X1, X2, v, w, kernel, out, differentiable, opt, mmv_fn):
+    def keops_dmmv_helper(
+        self,
+        X1,
+        X2,
+        v,
+        w,
+        kernel,
+        out,
+        differentiable,
+        opt,
+        mmv_fn,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]] = None,
+        kwargs_m2: Optional[Dict[str, torch.Tensor]] = None,
+    ):
         r"""
         performs fnc(X1*X2', X1, X2)' * ( fnc(X1*X2', X1, X2) * v + w )
 
@@ -159,6 +170,14 @@ def keops_dmmv_helper(self, X1, X2, v, w, kernel, out, differentiable, opt, mmv_
         mmv_fn : Callable
             The function which performs the mmv operation. Two mmv operations are (usually)
             needed for a dmmv operation.
+        kwargs_m1
+            Keyword arguments containing tensors which should be split along with ``X1``.
+            For example this could be a set of indices corresponding to ``X1``, which are then
+            correctly split and available in the kernel computation.
+        kwargs_m2
+            Keyword arguments containing tensors which should be split along with ``X2``.
+            For example this could be a set of indices corresponding to ``X2``, which are then
+            correctly split and available in the kernel computation.
 
         Notes
         ------
@@ -168,17 +187,62 @@ def keops_dmmv_helper(self, X1, X2, v, w, kernel, out, differentiable, opt, mmv_
         """
         if v is not None and w is not None:
-            out1 = mmv_fn(X1, X2, v, kernel, out=None, opt=opt)
+            out1 = mmv_fn(
+                X1,
+                X2,
+                v,
+                kernel,
+                out=None,
+                opt=opt,
+                kwargs_m1=kwargs_m1,
+                kwargs_m2=kwargs_m2,
+            )
             if differentiable:
                 out1 = out1.add(w)
             else:
                 out1.add_(w)
-            return mmv_fn(X2, X1, out1, kernel, out=out, opt=opt)
+            return mmv_fn(
+                X2,
+                X1,
+                out1,
+                kernel,
+                out=out,
+                opt=opt,
+                kwargs_m1=kwargs_m1,
+                kwargs_m2=kwargs_m2,
+            )
         elif v is None:
-            return mmv_fn(X2, X1, w, kernel, out=out, opt=opt)
+            return mmv_fn(
+                X2,
+                X1,
+                w,
+                kernel,
+                out=out,
+                opt=opt,
+                kwargs_m1=kwargs_m1,
+                kwargs_m2=kwargs_m2,
+            )
         elif w is None:
-            out1 = mmv_fn(X1, X2, v, kernel, out=None, opt=opt)
-            return mmv_fn(X2, X1, out1, kernel, out=out, opt=opt)
+            out1 = mmv_fn(
+                X1,
+                X2,
+                v,
+                kernel,
+                out=None,
+                opt=opt,
+                kwargs_m1=kwargs_m1,
+                kwargs_m2=kwargs_m2,
+            )
+            return mmv_fn(
+                X2,
+                X1,
+                out1,
+                kernel,
+                out=out,
+                opt=opt,
+                kwargs_m1=kwargs_m1,
+                kwargs_m2=kwargs_m2,
+            )
 
     # noinspection PyUnusedLocal
     def keops_can_handle_mm(self, X1, X2, opt) -> bool:
@@ -217,7 +281,17 @@ def _decide_dmmv_impl(self, X1, X2, v, w, opt: FalkonOptions):
         return super()._decide_dmmv_impl(X1, X2, v, w, opt)
 
     @abc.abstractmethod
-    def keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
+    def keops_mmv_impl(
+        self,
+        X1,
+        X2,
+        v,
+        kernel,
+        out,
+        opt: FalkonOptions,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]],
+        kwargs_m2: Optional[Dict[str, torch.Tensor]],
+    ):
         """Implementation of the KeOps formula to compute a kernel-vector product.
 
         Parameters
@@ -239,6 +313,14 @@ def keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
         opt : FalkonOptions
             Options to be used for computing the operation. Useful are the memory size options,
             CUDA options and KeOps options.
+        kwargs_m1
+            Keyword arguments containing tensors which should be split along with ``X1``.
+            For example this could be a set of indices corresponding to ``X1``, which are then
+            correctly split and available in the kernel computation.
+        kwargs_m2
+            Keyword arguments containing tensors which should be split along with ``X2``.
+            For example this could be a set of indices corresponding to ``X2``, which are then
+            correctly split and available in the kernel computation.
 
         Returns
         -------
diff --git a/falkon/kernels/kernel.py b/falkon/kernels/kernel.py
index ef63619..1efad27 100644
--- a/falkon/kernels/kernel.py
+++ b/falkon/kernels/kernel.py
@@ -139,6 +139,8 @@ def __call__(
         diag: bool = False,
         out: Optional[torch.Tensor] = None,
         opt: Optional[FalkonOptions] = None,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]] = None,
+        kwargs_m2: Optional[Dict[str, torch.Tensor]] = None,
     ):
         """Compute the kernel matrix between ``X1`` and ``X2``
 
@@ -158,6 +160,14 @@ def __call__(
         opt : Optional[FalkonOptions]
             Options to be used for computing the operation. Useful are the memory size options
             and CUDA options.
+        kwargs_m1
+            Keyword arguments containing tensors which should be split along with ``m1``.
+            For example this could be a set of indices corresponding to ``m1``, which are then
+            correctly split and available in the kernel computation.
+        kwargs_m2
+            Keyword arguments containing tensors which should be split along with ``m2``.
+            For example this could be a set of indices corresponding to ``m2``, which are then
+            correctly split and available in the kernel computation.
 
         Returns
         -------
@@ -170,7 +180,16 @@ def __call__(
         if opt is not None:
             params = dataclasses.replace(self.params, **dataclasses.asdict(opt))
         mm_impl = self._decide_mm_impl(X1, X2, diag, params)
-        return mm_impl(self, params, out, diag, X1, X2)
+        return mm_impl(
+            kernel=self,
+            opt=params,
+            out=out,
+            diag=diag,
+            X1=X1,
+            X2=X2,
+            kwargs_m1=kwargs_m1,
+            kwargs_m2=kwargs_m2,
+        )
 
     def _decide_mm_impl(self, X1: torch.Tensor, X2: torch.Tensor, diag: bool, opt: FalkonOptions):
         """Choose which `mm` function to use for this data.
@@ -211,6 +230,8 @@ def mmv(
         v: torch.Tensor,
         out: Optional[torch.Tensor] = None,
         opt: Optional[FalkonOptions] = None,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]] = None,
+        kwargs_m2: Optional[Dict[str, torch.Tensor]] = None,
     ):
         # noinspection PyShadowingNames
         """Compute matrix-vector multiplications where the matrix is the current kernel.
@@ -232,6 +253,14 @@ def mmv(
         opt : Optional[FalkonOptions]
             Options to be used for computing the operation. Useful are the memory size options
             and CUDA options.
+        kwargs_m1
+            Keyword arguments containing tensors which should be split along with ``m1``.
+            For example this could be a set of indices corresponding to ``m1``, which are then
+            correctly split and available in the kernel computation.
+        kwargs_m2
+            Keyword arguments containing tensors which should be split along with ``m2``.
+            For example this could be a set of indices corresponding to ``m2``, which are then
+            correctly split and available in the kernel computation.
 
         Returns
         -------
@@ -255,7 +284,16 @@ def mmv(
         if opt is not None:
             params = dataclasses.replace(self.params, **dataclasses.asdict(opt))
         mmv_impl = self._decide_mmv_impl(X1, X2, v, params)
-        return mmv_impl(X1, X2, v, self, out, params)
+        return mmv_impl(
+            X1=X1,
+            X2=X2,
+            v=v,
+            kernel=self,
+            out=out,
+            opt=params,
+            kwargs_m1=kwargs_m1,
+            kwargs_m2=kwargs_m2,
+        )
 
     def _decide_mmv_impl(
         self,
@@ -303,6 +341,8 @@ def dmmv(
         w: Optional[torch.Tensor],
         out: Optional[torch.Tensor] = None,
         opt: Optional[FalkonOptions] = None,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]] = None,
+        kwargs_m2: Optional[Dict[str, torch.Tensor]] = None,
     ):
         # noinspection PyShadowingNames
         """Compute double matrix-vector multiplications where the matrix is the current kernel.
@@ -332,6 +372,14 @@ def dmmv(
         opt : Optional[FalkonOptions]
             Options to be used for computing the operation. Useful are the memory size options
             and CUDA options.
+        kwargs_m1
+            Keyword arguments containing tensors which should be split along with ``X1``.
+            For example this could be a set of indices corresponding to ``X1``, which are then
+            correctly split and available in the kernel computation.
+        kwargs_m2
+            Keyword arguments containing tensors which should be split along with ``X2``.
+            For example this could be a set of indices corresponding to ``X2``, which are then
+            correctly split and available in the kernel computation.
 
         Returns
         -------
@@ -355,7 +403,18 @@ def dmmv(
         if opt is not None:
             params = dataclasses.replace(self.params, **dataclasses.asdict(opt))
         dmmv_impl = self._decide_dmmv_impl(X1, X2, v, w, params)
-        return dmmv_impl(X1, X2, v, w, self, out, False, params)
+        return dmmv_impl(
+            X1=X1,
+            X2=X2,
+            v=v,
+            w=w,
+            kernel=self,
+            out=out,
+            differentiable=False,
+            opt=params,
+            kwargs_m1=kwargs_m1,
+            kwargs_m2=kwargs_m2,
+        )
 
     def _decide_dmmv_impl(
         self,
@@ -400,7 +459,14 @@ def _decide_dmmv_impl(
         return fdmmv
 
     @abstractmethod
-    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool) -> torch.Tensor:
+    def compute(
+        self,
+        X1: torch.Tensor,
+        X2: torch.Tensor,
+        out: torch.Tensor,
+        diag: bool,
+        **kwargs,
+    ) -> torch.Tensor:
         """
         Compute the kernel matrix of ``X1`` and ``X2`` - without regards for differentiability.
 
@@ -418,6 +484,8 @@ def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: b
         diag : bool
             If true, ``X1`` and ``X2`` have the same shape, and only the diagonal of ``k(X1, X2)``
             is to be computed and stored in ``out``. Otherwise compute the full kernel matrix.
+        kwargs
+            Additional keyword arguments which may be used in computing the kernel values.
 
         Returns
         -------
@@ -428,7 +496,12 @@ def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: b
 
     @abstractmethod
     def compute_sparse(
-        self, X1: SparseTensor, X2: SparseTensor, out: torch.Tensor, diag: bool, **kwargs
+        self,
+        X1: SparseTensor,
+        X2: SparseTensor,
+        out: torch.Tensor,
+        diag: bool,
+        **kwargs,
     ) -> torch.Tensor:
         """
         Compute the kernel matrix of ``X1`` and ``X2`` which are two sparse matrices, storing the output
@@ -450,13 +523,13 @@ def compute_sparse(
         the keyword arguments passed by the :func:`falkon.mmv_ops.fmmv.sparse_mmv_run_thread` and
         :func:`falkon.mmv_ops.fmm.sparse_mm_run_thread` functions are:
 
-        - X1_csr : the X1 matrix in CSR format
-        - X2_csr : the X2 matrix in CSR format
+        - ``X1_csr`` : the ``X1`` matrix in CSR format
+        - ``X2_csr`` : the ``X2`` matrix in CSR format
 
         Returns
         -------
         out : torch.Tensor
-            The kernel matrix. Should use the same underlying storage as the parameter `out`.
+            The kernel matrix. Should use the same underlying storage as the parameter ``out``.
         """
         pass
diff --git a/falkon/mmv_ops/fmm.py b/falkon/mmv_ops/fmm.py
index 699497b..85f09a7 100644
--- a/falkon/mmv_ops/fmm.py
+++ b/falkon/mmv_ops/fmm.py
@@ -1,12 +1,18 @@
 from contextlib import ExitStack
-from dataclasses import dataclass
-from typing import Optional, Union
+from dataclasses import dataclass, field
+from typing import Optional, Union, Dict
 
 import torch
 import torch.cuda as tcd
 
 import falkon
-from falkon.mmv_ops.utils import _call_direct, _check_contiguity, _extract_flat, _get_gpu_info, _start_wait_processes
+from falkon.mmv_ops.utils import (
+    _call_direct,
+    _check_contiguity,
+    _extract_flat,
+    _get_gpu_info,
+    _start_wait_processes,
+)
 from falkon.options import BaseOptions
 from falkon.sparse.sparse_tensor import SparseTensor
 from falkon.utils.device_copy import copy
@@ -24,6 +30,8 @@ class ArgsFmm:
     max_mem: float
     differentiable: bool
     num_streams: int = 1
+    kwargs_m1: Dict[str, torch.Tensor] = field(default_factory=dict)
+    kwargs_m2: Dict[str, torch.Tensor] = field(default_factory=dict)
 
 
 def mm_run_starter(proc_idx, queue, device_id):
@@ -124,11 +132,47 @@ def mm_run_starter(proc_idx, queue, device_id):
 
     # Run
     if differentiable:
-        return mm_diff_run_thread(X1, X2, out, kernel, n, m, computation_dtype, dev, tid=proc_idx)
+        return mm_diff_run_thread(
+            X1,
+            X2,
+            out,
+            kernel,
+            n,
+            m,
+            computation_dtype,
+            dev,
+            tid=proc_idx,
+            kwargs_m1=a.kwargs_m1,
+            kwargs_m2=a.kwargs_m2,
+        )
     elif is_sparse:
-        return sparse_mm_run_thread(X1, X2, out, kernel, n, m, computation_dtype, dev, tid=proc_idx)
+        return sparse_mm_run_thread(
+            X1,
+            X2,
+            out,
+            kernel,
+            n,
+            m,
+            computation_dtype,
+            dev,
+            tid=proc_idx,
+            kwargs_m1=a.kwargs_m1,
+            kwargs_m2=a.kwargs_m2,
+        )
     else:
-        return mm_run_thread(X1, X2, out, kernel, n, m, computation_dtype, dev, tid=proc_idx)
+        return mm_run_thread(
+            X1,
+            X2,
+            out,
+            kernel,
+            n,
+            m,
+            computation_dtype,
+            dev,
+            tid=proc_idx,
+            kwargs_m1=a.kwargs_m1,
+            kwargs_m2=a.kwargs_m2,
+        )
 
 
 def sparse_mm_run_thread(
@@ -141,6 +185,8 @@ def sparse_mm_run_thread(
     comp_dt: torch.dtype,
     dev: torch.device,
     tid: int,
+    kwargs_m1: Dict[str, torch.Tensor],
+    kwargs_m2: Dict[str, torch.Tensor],
 ):
     """Inner loop to compute (part of) a kernel matrix for two sparse input tensors
 
@@ -166,6 +212,14 @@ def sparse_mm_run_thread(
         Device on which to run the calculations
     tid
         Thread ID. If on the main thread this will be -1
+    kwargs_m1
+        Keyword arguments containing tensors which should be split along with ``m1``.
+        For example this could be a set of indices corresponding to ``m1``, which are then
+        correctly split and available in the kernel computation.
+    kwargs_m2
+        Keyword arguments containing tensors which should be split along with ``m2``.
+        For example this could be a set of indices corresponding to ``m2``, which are then
+        correctly split and available in the kernel computation.
 
     Returns
     -------
@@ -193,6 +247,7 @@
 
     for j in range(0, M, m):
         lenj = min(m, M - j)
+        c_kwargs_m2 = {k: v[j : j + lenj] for k, v in kwargs_m2.items()}
         c_m2 = m2.narrow_rows(j, lenj).to(dtype=comp_dt)
 
         # On CUDA the second argument to apply (a Sparse*Sparse multiplication) must be
@@ -209,6 +264,7 @@
 
         for i in range(0, N, n):
             leni = min(n, N - i)
+            c_kwargs_m1 = {k: v[i : i + leni] for k, v in kwargs_m1.items()}
             c_m1 = m1.narrow_rows(i, leni).to(dtype=comp_dt)
 
             if dev.type == "cuda":
@@ -221,11 +277,25 @@
             else:
                 c_dev_out = out[i : i + leni, j : j + lenj]
             c_dev_out.fill_(0.0)
-            c_dev_out = kernel.compute_sparse(c_dev_m1, c_dev_m2, c_dev_out, diag=False, X1_csr=c_m1, X2_csr=c_m2)
+            c_dev_out = kernel.compute_sparse(
+                c_dev_m1,
+                c_dev_m2,
+                c_dev_out,
+                diag=False,
+                X1_csr=c_m1,
+                X2_csr=c_m2,
+                **c_kwargs_m1,
+                **c_kwargs_m2,
+            )
 
             # Copy back to host
             if has_gpu_bufs:
-                copy(c_dev_out, out[i : i + leni, j : j + lenj], non_blocking=True, allow_dtype_change=True)
+                copy(
+                    c_dev_out,
+                    out[i : i + leni, j : j + lenj],
+                    non_blocking=True,
+                    allow_dtype_change=True,
+                )
     if tid != -1 and stream is not None:
         stream.synchronize()
     return out
@@ -241,6 +311,8 @@ def mm_run_thread(
     comp_dt: torch.dtype,
     dev: torch.device,
     tid: int,
+    kwargs_m1: Dict[str, torch.Tensor],
+    kwargs_m2: Dict[str, torch.Tensor],
 ):
     is_ooc = dev.type != m1.device.type
     change_dtype = comp_dt != m1.dtype
@@ -270,17 +342,29 @@
 
     for i in range(0, N, n):
         leni = min(n, N - i)
+        c_kwargs_m1 = {k: v[i : i + leni] for k, v in kwargs_m1.items()}
         if has_gpu_bufs:
-            c_dev_m1 = copy(m1[i : i + leni, :], dev_m1[:leni, :], non_blocking=True, allow_dtype_change=True)
+            c_dev_m1 = copy(
+                m1[i : i + leni, :],
+                dev_m1[:leni, :],
+                non_blocking=True,
+                allow_dtype_change=True,
+            )
         else:
             c_dev_m1 = m1[i : i + leni, :]
 
         for j in range(0, M, m):
            lenj = min(m, M - j)
+            c_kwargs_m2 = {k: v[j : j + lenj] for k, v in kwargs_m2.items()}
             if has_gpu_bufs:
-                c_dev_m2 = copy(m2[j : j + lenj, :], dev_m2[:lenj, :], non_blocking=True, allow_dtype_change=True)
+                c_dev_m2 = copy(
+                    m2[j : j + lenj, :],
+                    dev_m2[:lenj, :],
+                    non_blocking=True,
+                    allow_dtype_change=True,
+                )
                 c_dev_out = dev_nm[:leni, :lenj]
             else:
                 c_dev_m2 = m2[j : j + lenj, :]
@@ -288,11 +372,23 @@
             c_dev_out.fill_(0.0)
 
             # Compute kernel sub-matrix
-            kernel.compute(c_dev_m1, c_dev_m2, c_dev_out, diag=False)
+            kernel.compute(
+                c_dev_m1,
+                c_dev_m2,
+                c_dev_out,
+                diag=False,
+                **c_kwargs_m1,
+                **c_kwargs_m2,
+            )
 
             # Copy back to host
             if has_gpu_bufs:
-                copy(c_dev_out, out[i : i + leni, j : j + lenj], non_blocking=True, allow_dtype_change=True)
+                copy(
+                    c_dev_out,
+                    out[i : i + leni, j : j + lenj],
+                    non_blocking=True,
+                    allow_dtype_change=True,
+                )
     if tid != -1 and stream is not None:
         stream.synchronize()
     return out
@@ -302,12 +398,14 @@
 def mm_diff_run_thread(
     m1: torch.Tensor,
     m2: torch.Tensor,
     out: torch.Tensor,
-    kernel: "falkon.kernels.Kernel",
+    kernel: "falkon.kernels.DiffKernel",
     n: int,
     m: int,
     comp_dt: torch.dtype,
     dev: torch.device,
     tid: int,
+    kwargs_m1: Dict[str, torch.Tensor],
+    kwargs_m2: Dict[str, torch.Tensor],
 ):
     N, D = m1.shape
     M = m2.shape[0]
@@ -323,11 +421,21 @@
 
     for i in range(0, N, n):
         leni = min(n, N - i)
+        c_kwargs_m1 = {k: v[i : i + leni] for k, v in kwargs_m1.items()}
+
         c_dev_m1 = m1[i : i + leni, :].to(device=dev, dtype=comp_dt, non_blocking=True, copy=False)
         for j in range(0, M, m):
             lenj = min(m, M - j)
+            c_kwargs_m2 = {k: v[j : j + lenj] for k, v in kwargs_m2.items()}
+
             c_dev_m2 = m2[j : j + lenj, :].to(device=dev, dtype=comp_dt, non_blocking=True, copy=False)
-            c_dev_out = kernel.compute_diff(c_dev_m1, c_dev_m2, diag=False)
+            c_dev_out = kernel.compute_diff(
+                c_dev_m1,
+                c_dev_m2,
+                diag=False,
+                **c_kwargs_m1,
+                **c_kwargs_m2,
+            )
             c_out = c_dev_out.to(device=out.device, dtype=out.dtype, non_blocking=False, copy=False)
             bwd_out = bwd_out + c_out.mul(out[i : i + leni, j : j + lenj]).sum()
     if tid != -1 and stream is not None:
@@ -337,10 +445,20 @@
         stream.synchronize()
     return bwd_out
 
 
 # noinspection PyMethodOverriding
 class KernelMmFnFull(torch.autograd.Function):
-    NUM_NON_DIFF_INPUTS = 4
+    NUM_NON_DIFF_INPUTS = 6
 
     @staticmethod
-    def run_cpu_cpu(X1, X2, out, kernel, dtype, options, diff):
+    def run_cpu_cpu(
+        X1,
+        X2,
+        out,
+        kernel,
+        dtype,
+        options,
+        diff,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]],
+        kwargs_m2: Optional[Dict[str, torch.Tensor]],
+    ):
         args = ArgsFmm(
             X1=X1,
             X2=X2,
@@ -350,12 +468,24 @@
             max_mem=options.max_cpu_mem,
             num_streams=1,
             differentiable=diff,
+            kwargs_m1=kwargs_m1 or {},
+            kwargs_m2=kwargs_m2 or {},
         )
         out = _call_direct(mm_run_starter, (args, -1))
         return out
 
     @staticmethod
-    def run_cpu_gpu(X1, X2, out, kernel, dtype, options, diff):
+    def run_cpu_gpu(
+        X1,
+        X2,
+        out,
+        kernel,
+        dtype,
+        options,
+        diff,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]],
+        kwargs_m2: Optional[Dict[str, torch.Tensor]],
+    ):
         gpu_info = _get_gpu_info(options, slack=options.memory_slack)
         block_sizes = calc_gpu_block_sizes(gpu_info, X1.shape[0])
         args = []  # Arguments passed to each subprocess
@@ -367,6 +497,9 @@
                 X1_block = X1.narrow_rows(block_sizes[i], bwidth)
             else:
                 X1_block = X1.narrow(0, block_sizes[i], bwidth)
+            c_kwargs_m1 = {}
+            if kwargs_m1 is not None:
+                c_kwargs_m1 = {k: v[block_sizes[i] : block_sizes[i] + bwidth] for k, v in kwargs_m1.items()}
             args.append(
                 (
                     ArgsFmm(
@@ -378,6 +511,8 @@
                         max_mem=g.usable_memory,
                         num_streams=options.num_fmm_streams,
                         differentiable=diff,
+                        kwargs_m1=c_kwargs_m1,
+                        kwargs_m2=kwargs_m2 or {},
                    ),
                    g.Id,
                )
@@ -388,7 +523,17 @@
        return call_outputs
 
     @staticmethod
-    def run_gpu_gpu(X1, X2, out, kernel, dtype, options, diff):
+    def run_gpu_gpu(
+        X1,
+        X2,
+        out,
+        kernel,
+        dtype,
+        options,
+        diff,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]],
+        kwargs_m2: Optional[Dict[str, torch.Tensor]],
+    ):
         if isinstance(X1, SparseTensor):
             raise NotImplementedError("In-core, sparse fmm not implemented. Use the out-of-core version instead.")
         data_dev = X1.device
@@ -403,6 +548,8 @@
             max_mem=single_gpu_info.usable_memory,
             num_streams=options.num_fmm_streams,
             differentiable=diff,
+            kwargs_m1=kwargs_m1 or {},
+            kwargs_m2=kwargs_m2 or {},
         )
         return _call_direct(mm_run_starter, (args, data_dev.index))
 
@@ -414,24 +561,30 @@ def run_diag(
         kernel: "falkon.kernels.Kernel",
         diff: bool,
         sparse: bool,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]],
+        kwargs_m2: Optional[Dict[str, torch.Tensor]],
     ) -> torch.Tensor:
+        kwargs_m1 = kwargs_m1 or {}
+        kwargs_m2 = kwargs_m2 or {}
         if diff:
-            out_ = kernel.compute_diff(X1, X2, diag=True)
+            out_ = kernel.compute_diff(X1, X2, diag=True, **kwargs_m1, **kwargs_m2)
             if not out_.requires_grad:
                 out_.requires_grad_()
             return out_.dot(out)
         else:
             if sparse:
                 # TODO: This is likely to fail due to missing kwargs on distance_kernels
-                return kernel.compute_sparse(X1, X2, out, diag=True)
+                return kernel.compute_sparse(X1, X2, out, diag=True, **kwargs_m1, **kwargs_m2)
             else:
-                return kernel.compute(X1, X2, out, diag=True)
+                return kernel.compute(X1, X2, out, diag=True, **kwargs_m1, **kwargs_m2)
 
     @staticmethod
     def forward(
         ctx,
         kernel: "falkon.kernels.Kernel",
         opt: Optional[BaseOptions],
+        kwargs_m1: Optional[Dict[str, torch.Tensor]],
+        kwargs_m2: Optional[Dict[str, torch.Tensor]],
         out: Optional[torch.Tensor],
         diag: bool,
         X1: Union[torch.Tensor, SparseTensor],
@@ -475,13 +628,13 @@ def forward(
 
         with torch.inference_mode():
             if diag:
-                out = KernelMmFnFull.run_diag(X1, X2, out, kernel, False, is_sparse)
+                out = KernelMmFnFull.run_diag(X1, X2, out, kernel, False, is_sparse, kwargs_m1, kwargs_m2)
             elif comp_dev_type == "cpu" and data_dev.type == "cpu":
-                out = KernelMmFnFull.run_cpu_cpu(X1, X2, out, kernel, comp_dtype, opt, False)
+                out = KernelMmFnFull.run_cpu_cpu(X1, X2, out, kernel, comp_dtype, opt, False, kwargs_m1, kwargs_m2)
             elif comp_dev_type == "cuda" and data_dev.type == "cuda":
-                out = KernelMmFnFull.run_gpu_gpu(X1, X2, out, kernel, comp_dtype, opt, False)
+                out = KernelMmFnFull.run_gpu_gpu(X1, X2, out, kernel, comp_dtype, opt, False, kwargs_m1, kwargs_m2)
             elif comp_dev_type == "cuda" and data_dev.type == "cpu":
-                out = KernelMmFnFull.run_cpu_gpu(X1, X2, out, kernel, comp_dtype, opt, False)
+                out = KernelMmFnFull.run_cpu_gpu(X1, X2, out, kernel, comp_dtype, opt, False, kwargs_m1, kwargs_m2)
             else:
                 raise RuntimeError("Requested CPU computations with CUDA data. This should not happen.")
@@ -493,6 +646,8 @@ def forward(
         ctx.opt = opt
         ctx.comp_dtype = comp_dtype
         ctx.diag = diag
+        ctx.kwargs_m1 = kwargs_m1
+        ctx.kwargs_m2 = kwargs_m2
         return out
 
     @staticmethod
@@ -507,13 +662,19 @@ def backward(ctx, outputs):
         with torch.autograd.enable_grad():
             if ctx.diag:
                 # TODO: Handle sparsity better
-                out = KernelMmFnFull.run_diag(X1, X2, outputs, ctx.kernel, True, sparse=False)
+                out = KernelMmFnFull.run_diag(X1, X2, outputs, ctx.kernel, True, False, ctx.kwargs_m1, ctx.kwargs_m2)
             elif comp_dev_type == "cpu" and data_dev.type == "cpu":
-                out = KernelMmFnFull.run_cpu_cpu(X1, X2, outputs, ctx.kernel, ctx.comp_dtype, ctx.opt, True)
+                out = KernelMmFnFull.run_cpu_cpu(
+                    X1, X2, outputs, ctx.kernel, ctx.comp_dtype, ctx.opt, True, ctx.kwargs_m1, ctx.kwargs_m2
+                )
             elif comp_dev_type == "cuda" and data_dev.type == "cuda":
-                out = KernelMmFnFull.run_gpu_gpu(X1, X2, outputs, ctx.kernel, ctx.comp_dtype, ctx.opt, True)
+                out = KernelMmFnFull.run_gpu_gpu(
+                    X1, X2, outputs, ctx.kernel, ctx.comp_dtype, ctx.opt, True, ctx.kwargs_m1, ctx.kwargs_m2
+                )
             elif comp_dev_type == "cuda" and data_dev.type == "cpu":
-                out = KernelMmFnFull.run_cpu_gpu(X1, X2, outputs, ctx.kernel, ctx.comp_dtype, ctx.opt, True)
+                out = KernelMmFnFull.run_cpu_gpu(
+                    X1, X2, outputs, ctx.kernel, ctx.comp_dtype, ctx.opt, True, ctx.kwargs_m1, ctx.kwargs_m2
+                )
             else:
                 raise RuntimeError("Requested CPU computations with CUDA data. This should not happen.")
         if isinstance(out, (tuple, list)):
@@ -547,10 +708,12 @@ def fmm(
     diag: bool,
     X1: Union[torch.Tensor, SparseTensor],
    X2: Union[torch.Tensor, SparseTensor],
+    kwargs_m1: Optional[Dict[str, torch.Tensor]] = None,
+    kwargs_m2: Optional[Dict[str, torch.Tensor]] = None,
 ):
     import falkon.kernels
 
     if isinstance(kernel, falkon.kernels.DiffKernel):
-        return KernelMmFnFull.apply(kernel, opt, out, diag, X1, X2, *kernel.diff_params.values())
+        return KernelMmFnFull.apply(kernel, opt, kwargs_m1, kwargs_m2, out, diag, X1, X2, *kernel.diff_params.values())
     else:
-        return KernelMmFnFull.apply(kernel, opt, out, diag, X1, X2)
+        return KernelMmFnFull.apply(kernel, opt, kwargs_m1, kwargs_m2, out, diag, X1, X2)
diff --git a/falkon/mmv_ops/fmmv.py b/falkon/mmv_ops/fmmv.py
index 4ac3a54..797cf9e 100644
--- a/falkon/mmv_ops/fmmv.py
+++ b/falkon/mmv_ops/fmmv.py
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from contextlib import ExitStack
-from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from dataclasses import dataclass, field
+from typing import Optional, Tuple, Union, Dict
 
 import numpy as np
 import torch
@@ -22,7 +22,12 @@
 from falkon.options import BaseOptions
 from falkon.sparse import SparseTensor
 from falkon.utils.device_copy import copy
-from falkon.utils.helpers import calc_gpu_block_sizes, select_dim_over_n, select_dim_over_nm_v2, sizeof_dtype
+from falkon.utils.helpers import (
+    calc_gpu_block_sizes,
+    select_dim_over_n,
+    select_dim_over_nm_v2,
+    sizeof_dtype,
+)
 from falkon.utils.tensor_helpers import create_same_stride, extract_fortran
 
 
@@ -36,6 +41,8 @@ class ArgsFmmv:
     max_mem: float
     w: torch.Tensor = None
     differentiable: bool = False
+    kwargs_m1: Dict[str, torch.Tensor] = field(default_factory=dict)
+    kwargs_m2: Dict[str, torch.Tensor] = field(default_factory=dict)
 
 
 def _init_two_streams(
@@ -148,11 +155,39 @@ def mmv_run_starter(proc_idx, queue, device_id):
     )
     if differentiable:
         assert not is_sparse, "Sparse + differentiable mmvs are not supported"
-        return mmv_diff_run_thread(X1, X2, v, out, kernel, blk_n, blk_m, dev, tid=proc_idx)
+        return mmv_diff_run_thread(
+            X1, X2, v, out, kernel, blk_n, blk_m, dev, tid=proc_idx, kwargs_m1=a.kwargs_m1, kwargs_m2=a.kwargs_m2
+        )
     if is_sparse:
-        return sparse_mmv_run_thread(X1, X2, v, out, kernel, blk_n, blk_m, mem_needed, dev, tid=proc_idx)
+        return sparse_mmv_run_thread(
+            X1,
+            X2,
+            v,
+            out,
+            kernel,
+            blk_n,
+            blk_m,
+            mem_needed,
+            dev,
+            tid=proc_idx,
+            kwargs_m1=a.kwargs_m1,
+            kwargs_m2=a.kwargs_m2,
+        )
     else:
-        return mmv_run_thread(X1, X2, v, out, kernel, blk_n, blk_m, mem_needed, dev, tid=proc_idx)
+        return mmv_run_thread(
+            X1,
+            X2,
+            v,
+            out,
+            kernel,
+            blk_n,
+            blk_m,
+            mem_needed,
+            dev,
+            tid=proc_idx,
+            kwargs_m1=a.kwargs_m1,
+            kwargs_m2=a.kwargs_m2,
+        )
 
 
 def sparse_mmv_run_thread(
@@ -166,6 +201,8 @@
     mem_needed: int,
     dev: torch.device,
     tid: int,
+    kwargs_m1: Dict[str, torch.Tensor],
+    kwargs_m2: Dict[str, torch.Tensor],
 ):
     """Inner loop to compute (part of) a kernel-vector product for sparse input matrices.
 
@@ -193,6 +230,14 @@
         Device on which to run the calculations
     tid
         Thread ID or -1 if on main thread
+    kwargs_m1
+        Keyword arguments containing tensors which should be split along with ``m1``.
+        For example this could be a set of indices corresponding to ``m1``, which are then
+        correctly split and available in the kernel computation.
+    kwargs_m2
+        Keyword arguments containing tensors which should be split along with ``m2``.
+        For example this could be a set of indices corresponding to ``m2``, which are then
+        correctly split and available in the kernel computation.
 
     Returns
     -------
@@ -218,6 +263,7 @@
         s1, s2 = _init_two_streams(stack, dev, tid)  # enters stream 1
         for i in range(0, N, blk_n):
             leni = min(blk_n, N - i)
+            c_kwargs_m1 = {k: v[i : i + leni] for k, v in kwargs_m1.items()}
             c_m1 = m1.narrow_rows(i, leni)
             if incore:  # Note that CUDA-incore is not allowed to happen (so this is CPU->CPU)
@@ -230,6 +276,7 @@
 
             for j in range(0, M, blk_m):
                 lenj = min(blk_m, M - j)
+                c_kwargs_m2 = {k: v[j : j + lenj] for k, v in kwargs_m2.items()}
                 c_m2 = m2.narrow_rows(j, lenj)
                 if incore:  # CPU -> CPU
@@ -247,7 +294,9 @@
                     c_dev_v = copy(v[j : j + lenj], dev_v[:lenj], non_blocking=True)
 
                 c_dev_ker = ker_gpu[:leni, :lenj].fill_(0.0)
-                c_dev_ker = kernel.compute_sparse(c_dev_m1, c_dev_m2, c_dev_ker, diag=False, X1_csr=c_m1, X2_csr=c_m2)
+                c_dev_ker = kernel.compute_sparse(
+                    c_dev_m1, c_dev_m2, c_dev_ker, diag=False, X1_csr=c_m1, X2_csr=c_m2, **c_kwargs_m1, **c_kwargs_m2
+                )
                 if not incore:
                     s2.synchronize()
                 c_dev_out.addmm_(c_dev_ker, c_dev_v)
@@ -273,6 +322,8 @@ def mmv_run_thread(
     mem_needed: int,
     dev: torch.device,
     tid: int,
+    kwargs_m1: Dict[str, torch.Tensor],
+    kwargs_m2: Dict[str, torch.Tensor],
 ):
     # data(CUDA), dev(CUDA) or data(CPU), dev(CPU)
     m1_ic, m2_ic, v_ic, out_ic = (
@@ -311,8 +362,7 @@
         s1, s2 = _init_two_streams(stack, dev, tid)
         for i in range(0, N, blk_n):
             leni = min(blk_n, N - i)
-            # c_dev_m1 = m1[i: i + leni, :].to(dev, copy=False, non_blocking=True)
-            # c_dev_out = out[i: i + leni, :].to(dev, copy=False, non_blocking=True)
+            c_kwargs_m1 = {k: v[i : i + leni] for k, v in kwargs_m1.items()}
             if m1_ic:
                 c_dev_m1 = m1[i : i + leni, :]
             else:
@@ -325,6 +375,7 @@
 
             for j in range(0, M, blk_m):
                 lenj = min(blk_m, M - j)
+                c_kwargs_m2 = {k: v[j : j + lenj] for k, v in kwargs_m2.items()}
                 if m2_ic:
                     c_dev_m2 = m2[j : j + lenj, :]
                 else:
@@ -337,13 +388,9 @@
                     if dev.type == "cuda":
                         stack2.enter_context(tcd.stream(s2))
                     c_dev_v = copy(v[j : j + lenj, :], dev_v[:lenj, :], non_blocking=True)
-                # with ExitStack() as stack2:
-                #     if dev.type == 'cuda':
-                #         stack2.enter_context(tcd.stream(s2))
-                #     c_dev_v = v[j: j + lenj, :].to(dev, copy=False, non_blocking=True)
 
                 c_dev_ker = dev_ker[:leni, :lenj].fill_(0.0)
-                c_dev_ker = kernel.compute(c_dev_m1, c_dev_m2, c_dev_ker, diag=False)
+                c_dev_ker = kernel.compute(c_dev_m1, c_dev_m2, c_dev_ker, diag=False, **c_kwargs_m1, **c_kwargs_m2)
                 if not incore:
                     s2.synchronize()
                 c_dev_out.addmm_(c_dev_ker, c_dev_v)
@@ -370,6 +417,8 @@ def mmv_diff_run_thread(
     blk_m: int,
     dev: torch.device,
     tid: int,
+    kwargs_m1: Dict[str, torch.Tensor],
+    kwargs_m2: Dict[str, torch.Tensor],
 ):
     # data(CUDA), dev(CUDA) or data(CPU), dev(CPU)
     incore = _is_incore(dev, m1.device)
@@ -389,12 +438,14 @@
         s1, s2 = _init_two_streams(stack, dev, tid)
         for i in range(0, N, blk_n):
             leni = min(blk_n, N - i)
+            c_kwargs_m1 = {k: v[i : i + leni] for k, v in kwargs_m1.items()}
             c_dev_m1 = m1[i : i + leni, :].to(dev, non_blocking=True, copy=False)
             c_dev_m1_g = None if grads[0] is None else grads[0][i : i + leni, :].to(dev, non_blocking=True, copy=False)
             c_dev_out = out[i : i + leni, :].to(dev, non_blocking=True, copy=False)
 
             for j in range(0, M, blk_m):
                 lenj = min(blk_m, M - j)
+                c_kwargs_m2 = {k: v[j : j + lenj] for k, v in kwargs_m2.items()}
                 c_dev_m2 = m2[j : j + lenj, :].to(dev, non_blocking=True, copy=False)
                 c_dev_m2_g = (
                     None if grads[1] is None else grads[1][j : j + lenj, :].to(dev, non_blocking=True, copy=False)
@@ -406,7 +457,7 @@
                 c_dev_v_g = (
                     None if grads[2] is None else grads[2][j : j + lenj, :].to(dev, non_blocking=True, copy=False)
                 )
-                c_dev_ker = kernel.compute_diff(c_dev_m1, c_dev_m2, diag=False)
+                c_dev_ker = kernel.compute_diff(c_dev_m1, c_dev_m2, diag=False, **c_kwargs_m1, **c_kwargs_m2)
                 if not incore:
                     s2.synchronize()
                 # main MMV operation on current block
@@ -414,7 +465,9 @@
                 # Build inputs for torch.autograd.grad
                 c_inputs = [c_dev_m1, c_dev_m2, c_dev_v] + list(kernel.diff_params.values())
                 c_dev_grads = torch.autograd.grad(
-                    c_dev_mmv, [c_inputs[idx] for idx in input_idxs], grad_outputs=c_dev_out
+                    c_dev_mmv,
+                    [c_inputs[idx] for idx in input_idxs],
+                    grad_outputs=c_dev_out,
                 )
                 c_dev_grads_old = [c_dev_m1_g, c_dev_m2_g, c_dev_v_g] + grads[3:]
                 for c_grad, c_idx in zip(c_dev_grads, input_idxs):
@@ -527,9 +580,35 @@ def dmmv_run_starter(proc_idx, queue, device_id):
     )
     if is_sparse:
-        sparse_dmmv_run_thread(X1, X2, v, w, out, kernel, blk_n, mem_needed, dev, tid=proc_idx)
+        sparse_dmmv_run_thread(
+            X1,
+            X2,
+            v,
+            w,
+            out,
+            kernel,
+            blk_n,
+            mem_needed,
+            dev,
+            tid=proc_idx,
+            kwargs_m1=a.kwargs_m1,
+            kwargs_m2=a.kwargs_m2,
+        )
     else:
-        dmmv_run_thread(X1, X2, v, w, out, kernel, blk_n, mem_needed, dev, tid=proc_idx)
+        dmmv_run_thread(
+            X1,
+            X2,
+            v,
+            w,
+            out,
+            kernel,
+            blk_n,
+            mem_needed,
+            dev,
+            tid=proc_idx,
+            kwargs_m1=a.kwargs_m1,
+            kwargs_m2=a.kwargs_m2,
+        )
 
 
 def sparse_dmmv_run_thread(
@@ -543,6 +622,8 @@
     mem_needed: int,
     dev: torch.device,
     tid: int,
+    kwargs_m1: Dict[str, torch.Tensor],
+    kwargs_m2: Dict[str, torch.Tensor],
 ):
     incore = _is_incore(dev, m1.device)
     dev_out_exists = out.device == dev  # out has already been allocated on the computation device
@@ -579,6 +660,7 @@
 
     for i in range(0, N, blk_n):
         leni = min(blk_n, N - i)
+        c_kwargs_m1 = {k: v[i : i + leni] for k, v in kwargs_m1.items()}
         c_m1 = m1.narrow_rows(i, leni)
         if incore:  # Note that CUDA-incore is not allowed to happen (so this is CPU->CPU)
@@ -591,7 +673,9 @@ def sparse_dmmv_run_thread(
             c_dev_w = copy(w[i : i + leni, :], dev_w[:leni, :], non_blocking=True)
 
         c_dev_ker = ker_gpu[:leni].fill_(0.0)
-        c_dev_ker = kernel.compute_sparse(c_dev_m1, dev_m2, c_dev_ker, diag=False, X1_csr=c_m1, X2_csr=m2)
+        c_dev_ker = kernel.compute_sparse(
+            c_dev_m1, dev_m2, c_dev_ker, diag=False, X1_csr=c_m1, X2_csr=m2, **c_kwargs_m1, **kwargs_m2
+        )
         c_dev_w.addmm_(c_dev_ker, dev_v)
         dev_out.addmm_(c_dev_ker.T, c_dev_w)
 
@@ -612,6 +696,8 @@ def dmmv_run_thread(
     mem_needed: int,
     dev: torch.device,
     tid: int,
+    kwargs_m1: Dict[str, torch.Tensor],
+    kwargs_m2: Dict[str, torch.Tensor],
 ):
     # k(x2, x1) @ (k(x1, x2) @ v + w)
     # data(CUDA), dev(CUDA) or data(CPU), dev(CPU)
@@ -658,6 +744,7 @@
         copy(v, dev_v, non_blocking=True)
     for i in range(0, N, blk_n):
         leni = min(blk_n, N - i)
+        c_kwargs_m1 = {k: v[i : i + leni] for k, v in kwargs_m1.items()}
         if m1_ic:
             c_dev_m1 = m1[i : i + leni, :]
         else:
@@ -671,7 +758,7 @@
 
         c_dev_w = dev_w[:leni, :].fill_(0.0)
         c_dev_ker = dev_ker[:leni, :].fill_(0.0)
-        c_dev_ker = kernel.compute(c_dev_m1, dev_m2, c_dev_ker, diag=False)
+        c_dev_ker = kernel.compute(c_dev_m1, dev_m2, c_dev_ker, diag=False, **c_kwargs_m1, **kwargs_m2)
         if s2 is not None:
             s2.synchronize()
         c_dev_w.addmm_(c_dev_ker, dev_v)
@@ -696,8 +783,20 @@ def run_cpu_cpu(
         kernel,
         options,
         diff,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]],
+        kwargs_m2: Optional[Dict[str, torch.Tensor]],
     ):
-        args = ArgsFmmv(X1=X1, X2=X2, v=v, out=out, kernel=kernel, max_mem=options.max_cpu_mem, differentiable=diff)
+        args = ArgsFmmv(
+            X1=X1,
+            X2=X2,
+            v=v,
+            out=out,
+            kernel=kernel,
+            max_mem=options.max_cpu_mem,
+            differentiable=diff,
+            kwargs_m1=kwargs_m1 or {},
+            kwargs_m2=kwargs_m2 or {},
+        )
         return _call_direct(mmv_run_starter, (args, -1))
 
     @staticmethod
@@ -709,6 +808,8 @@ def run_cpu_gpu(
         kernel,
         options,
         diff,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]],
+        kwargs_m2: Optional[Dict[str, torch.Tensor]],
     ):
         is_sparse = isinstance(X1, SparseTensor)
         gpu_info = _get_gpu_info(options, slack=options.memory_slack)
@@ -722,6 +823,9 @@ def run_cpu_gpu(
                 X1_block = X1.narrow_rows(block_sizes[i], bwidth)
             else:
                 X1_block = X1.narrow(0, block_sizes[i], bwidth)
+            c_kwargs_m1 = {}
+            if kwargs_m1 is not None:
+                c_kwargs_m1 = {k: v[block_sizes[i] : block_sizes[i] + bwidth] for k, v in kwargs_m1.items()}
             args.append(
                 (
                     ArgsFmmv(
@@ -732,6 +836,8 @@ def run_cpu_gpu(
                         kernel=kernel,
                         max_mem=g.usable_memory,
                         differentiable=diff,
+                        kwargs_m1=c_kwargs_m1,
+                        kwargs_m2=kwargs_m2 or {},
                    ),
                    g.Id,
                )
@@ -761,6 +867,8 @@ def run_gpu_gpu(
        kernel,
        options,
        diff,
+        kwargs_m1: Optional[Dict[str, torch.Tensor]],
+        kwargs_m2: Optional[Dict[str, torch.Tensor]],
     ):
         if isinstance(X1, SparseTensor):
             raise NotImplementedError("In-core, sparse fmmv not implemented. Use the out-of-core version instead.")
         data_dev = X1.device
@@ -768,7 +876,15 @@ def run_gpu_gpu(
         gpu_info = _get_gpu_info(options, slack=options.memory_slack)
         single_gpu_info = [g for g in gpu_info if g.Id == data_dev.index][0]
         args = ArgsFmmv(
-            X1=X1, X2=X2, v=v, out=out, kernel=kernel, max_mem=single_gpu_info.usable_memory, differentiable=diff
+            X1=X1,
+            X2=X2,
+            v=v,
+            out=out,
+            kernel=kernel,
+            max_mem=single_gpu_info.usable_memory,
+            differentiable=diff,
+            kwargs_m1=kwargs_m1 or {},
+            kwargs_m2=kwargs_m2 or {},
         )
         return _call_direct(mmv_run_starter, (args, data_dev.index))
 
@@ -777,6 +893,8 @@ def forward(
         ctx,
         kernel: "falkon.kernels.Kernel",
         opt: Optional[BaseOptions],
+        kwargs_m1: Optional[Dict[str, torch.Tensor]],
+        kwargs_m2: Optional[Dict[str, torch.Tensor]],
         out: Optional[torch.Tensor],
         X1: Union[torch.Tensor, SparseTensor],
         X2: Union[torch.Tensor, SparseTensor],
@@ -800,11 +918,11 @@ def forward(
 
         with torch.inference_mode():
             if comp_dev_type == "cpu" and all(ddev.type == "cpu" for ddev in data_devs):
-                KernelMmvFnFull.run_cpu_cpu(X1, X2, v, out, kernel, opt, False)
+                KernelMmvFnFull.run_cpu_cpu(X1, X2, v, out, kernel, opt, False, kwargs_m1, kwargs_m2)
             elif comp_dev_type == "cuda" and all(ddev.type == "cuda" for ddev in data_devs):
-                KernelMmvFnFull.run_gpu_gpu(X1, X2, v, out, kernel, opt, False)
+                KernelMmvFnFull.run_gpu_gpu(X1, X2, v, out, kernel, opt, False, kwargs_m1, kwargs_m2)
             elif comp_dev_type == "cuda":
-                KernelMmvFnFull.run_cpu_gpu(X1, X2, v, out, kernel, opt, False)
+                KernelMmvFnFull.run_cpu_gpu(X1, X2, v, out, kernel, opt, False, kwargs_m1, kwargs_m2)
             else:
                 raise RuntimeError("Requested CPU computations with CUDA data. This should not happen.")
@@ -814,6 +932,8 @@ def forward(
         ctx.save_for_backward(X1, X2, v, *kernel_params)
         ctx.kernel = kernel
         ctx.opt = opt
+        ctx.kwargs_m1 = kwargs_m1
+        ctx.kwargs_m2 = kwargs_m2
         return out
 
     @staticmethod
@@ -826,14 +946,20 @@ def backward(ctx, outputs):
         # We must rerun MM in differentiable mode this time.
         with torch.autograd.enable_grad():
             if comp_dev_type == "cpu" and data_dev.type == "cpu":
-                grads = KernelMmvFnFull.run_cpu_cpu(X1, X2, v, outputs, ctx.kernel, ctx.opt, True)
+                grads = KernelMmvFnFull.run_cpu_cpu(
+                    X1, X2, v, outputs, ctx.kernel, ctx.opt, True, ctx.kwargs_m1, ctx.kwargs_m2
+                )
             elif comp_dev_type == "cuda" and data_dev.type == "cuda":
-                grads = KernelMmvFnFull.run_gpu_gpu(X1, X2, v, outputs, ctx.kernel, ctx.opt, True)
+                grads = KernelMmvFnFull.run_gpu_gpu(
+                    X1, X2, v, outputs, ctx.kernel, ctx.opt, True, ctx.kwargs_m1, ctx.kwargs_m2
+                )
             elif comp_dev_type == "cuda" and data_dev.type == "cpu":
-                grads = KernelMmvFnFull.run_cpu_gpu(X1, X2, v, outputs, ctx.kernel, ctx.opt, True)
+                grads = KernelMmvFnFull.run_cpu_gpu(
+                    X1, X2, v, outputs, ctx.kernel, ctx.opt, True, ctx.kwargs_m1, ctx.kwargs_m2
+                )
             else:
                 raise RuntimeError("Requested CPU computations with CUDA data. This should not happen.")
-        return tuple([None, None, None] + grads)
+        return tuple([None, None, None, None, None] + grads)
 
 
 def fmmv(
     X1: Union[torch.Tensor, SparseTensor],
     X2: Union[torch.Tensor, SparseTensor],
     v: torch.Tensor,
     kernel: "falkon.kernels.Kernel",
     out: Optional[torch.Tensor] = None,
     opt: Optional[BaseOptions] = None,
+    kwargs_m1: Optional[Dict[str, torch.Tensor]] = None,
+    kwargs_m2: Optional[Dict[str, torch.Tensor]] = None,
 ):
     if isinstance(kernel, falkon.kernels.DiffKernel):
-        return KernelMmvFnFull.apply(kernel, opt, out, X1, X2, v, *kernel.diff_params.values())
+        return KernelMmvFnFull.apply(kernel, opt, kwargs_m1, kwargs_m2, out, X1, X2, v, *kernel.diff_params.values())
     else:
-        return KernelMmvFnFull.apply(kernel, opt, out, X1, X2, v)
+        return KernelMmvFnFull.apply(kernel, opt, kwargs_m1, kwargs_m2, out, X1, X2, v)
 
 
 def fdmmv(
     X1: Union[torch.Tensor, SparseTensor],
     X2: Union[torch.Tensor, SparseTensor],
     v: torch.Tensor,
     w: Optional[torch.Tensor],
     out: Optional[torch.Tensor] = None,
     differentiable: bool = False,
     opt: Optional[BaseOptions] = None,
+    kwargs_m1: Optional[Dict[str, torch.Tensor]] = None,
+    kwargs_m2: Optional[Dict[str, torch.Tensor]] = None,
 ) -> torch.Tensor:
     r"""Double kernel-vector product
 
@@ -885,6 +1015,14 @@ def fdmmv(
         to ``True`` results in a :code:`NotImplementedError`.
     opt
         Options to be used for this operation
+    kwargs_m1
+        Keyword arguments containing tensors which should be split along with ``m1``.
+        For example this could be a set of indices corresponding to ``m1``, which are then
+        correctly split and available in the kernel computation.
+    kwargs_m2
+        Keyword arguments containing tensors which should be split along with ``m2``.
+        For example this could be a set of indices corresponding to ``m2``, which are then
+        correctly split and available in the kernel computation.
 
     Returns
     -------
@@ -914,7 +1052,17 @@ def fdmmv(
     )
 
     if comp_dev_type == "cpu" and all(ddev.type == "cpu" for ddev in data_devs):
-        args = ArgsFmmv(X1=X1, X2=X2, v=v, w=w, out=out, kernel=kernel, max_mem=opt.max_cpu_mem)
+        args = ArgsFmmv(
+            X1=X1,
+            X2=X2,
+            v=v,
+            w=w,
+            out=out,
+            kernel=kernel,
+            max_mem=opt.max_cpu_mem,
+            kwargs_m1=kwargs_m1 or {},
+            kwargs_m2=kwargs_m2 or {},
+        )
         _call_direct(dmmv_run_starter, (args, -1))
     elif comp_dev_type == "cuda" and all(ddev.type == "cuda" for ddev in data_devs):
         if is_sparse:
@@ -922,7 +1070,17 @@ def fdmmv(
         gpu_info = _get_gpu_info(opt, slack=opt.memory_slack)
         data_dev = data_devs[0]
         single_gpu_info = [g for g in gpu_info if g.Id == data_dev.index][0]
-        args = ArgsFmmv(X1=X1, X2=X2, v=v, w=w, out=out, kernel=kernel, max_mem=single_gpu_info.usable_memory)
+        args = ArgsFmmv(
+            X1=X1,
+            X2=X2,
+            v=v,
+            w=w,
+            out=out,
+            kernel=kernel,
+            max_mem=single_gpu_info.usable_memory,
+            kwargs_m1=kwargs_m1 or {},
+            kwargs_m2=kwargs_m2 or {},
+        )
         _call_direct(dmmv_run_starter, (args, data_dev.index))
     elif comp_dev_type == "cuda":
         gpu_info = _get_gpu_info(opt, slack=opt.memory_slack)
@@ -943,6 +1101,9 @@ def fdmmv(
                 X1_block = X1.narrow_rows(block_sizes[i], bwidth)
             else:
                 X1_block = X1.narrow(0, block_sizes[i], bwidth)
+            c_kwargs_m1 = {}
+            if kwargs_m1 is not None:
+                c_kwargs_m1 = {k: v[block_sizes[i] : block_sizes[i] + bwidth] for k, v in kwargs_m1.items()}
             args.append(
                 (
                     ArgsFmmv(
@@ -953,6 +1114,8 @@ def fdmmv(
                         out=cur_out_gpu,
                         kernel=kernel,
                         max_mem=g.usable_memory,
+                        kwargs_m1=c_kwargs_m1,
+                        kwargs_m2=kwargs_m2 or {},
                    ),
                    g.Id,
                )
@@ -961,7 +1124,10 @@ def fdmmv(
        if len(wrlk) > 1:
            # Sum up all subprocess outputs and copy to `out` on host.
            # noinspection PyTypeChecker
            fastest_device: int = np.argmax([d.speed for d in gpu_info])
-            copy(torch.cuda.comm.reduce_add(wrlk, destination=gpu_info[fastest_device].Id), out)
+            copy(
+                torch.cuda.comm.reduce_add(wrlk, destination=gpu_info[fastest_device].Id),
+                out,
+            )
        else:
            if wrlk[0].data_ptr() != out.data_ptr():
                copy(wrlk[0], out)
diff --git a/pyproject.toml b/pyproject.toml
index 3ff25a4..6911acc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,9 @@ omit = [
     "falkon/tests/*", "falkon/hopt/*", "falkon/benchmarks/*", "falkon/csrc/*",
 ]
 
+[tool.black]
+line-length = 120
+
 [tool.ruff]
 target-version = "py38"
 ignore = [
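
Usage sketch for the new kwargs_m1/kwargs_m2 machinery (illustrative only; the kernel class, the w1/w2 names, and the plain Kernel(name, opt) constructor below are assumptions, not part of this patch). The only interfaces relied on are the ones the diff adds: Kernel.compute(..., **kwargs) and Kernel.mmv(..., kwargs_m1=..., kwargs_m2=...).

    import torch
    from falkon.kernels import Kernel


    class RowWeightedLinearKernel(Kernel):
        # Hypothetical kernel k(x, y) = w1 * w2 * <x, y>, with one weight per
        # row of X1 and X2, delivered through kwargs_m1/kwargs_m2.
        def __init__(self, opt=None):
            super().__init__("row-weighted-linear", opt)  # assumed Kernel.__init__(name, opt)

        def compute(self, X1, X2, out, diag, w1=None, w2=None, **kwargs):
            if diag:
                raise NotImplementedError("diagonal path omitted in this sketch")
            # w1 and w2 arrive here already sliced to match the current X1/X2 blocks.
            out = torch.mm(X1, X2.T, out=out)
            if w1 is not None:
                out.mul_(w1.view(-1, 1))
            if w2 is not None:
                out.mul_(w2.view(1, -1))
            return out

        def compute_sparse(self, X1, X2, out, diag, **kwargs):
            raise NotImplementedError("sparse path omitted in this sketch")


    X1, X2, v = torch.randn(1000, 5), torch.randn(200, 5), torch.randn(200, 1)
    k = RowWeightedLinearKernel()
    # w1 is split along the rows of X1, w2 along the rows of X2, so every block
    # computation sees exactly the slices matching its data.
    kv = k.mmv(X1, X2, v, kwargs_m1={"w1": torch.rand(1000)}, kwargs_m2={"w2": torch.rand(200)})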
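
The blockwise splitting rule the *_run_thread loops apply to these dictionaries can be shown standalone: every tensor in kwargs_m1 is sliced with the same row ranges as m1, and every tensor in kwargs_m2 with the same ranges as m2 (the sizes below are made up for illustration).

    import torch

    N, M, blk_n, blk_m = 10, 8, 4, 3
    kwargs_m1 = {"indices_m1": torch.arange(N)}
    kwargs_m2 = {"indices_m2": torch.arange(M)}
    for i in range(0, N, blk_n):
        leni = min(blk_n, N - i)
        c_kwargs_m1 = {k: v[i : i + leni] for k, v in kwargs_m1.items()}
        for j in range(0, M, blk_m):
            lenj = min(blk_m, M - j)
            c_kwargs_m2 = {k: v[j : j + lenj] for k, v in kwargs_m2.items()}
            # At this point the runners forward the slices to the kernel:
            # kernel.compute(c_m1, c_m2, c_out, diag=False, **c_kwargs_m1, **c_kwargs_m2)
            print(c_kwargs_m1["indices_m1"].tolist(), c_kwargs_m2["indices_m2"].tolist())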