diff --git a/docs/source/conf.py b/docs/source/conf.py index 4d66c50e..e5d90512 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -116,14 +116,20 @@ def _dim_to_str(dim): if isinstance(dim, jaxtyping.array_types._NamedVariadicDim): return "..." elif isinstance(dim, jaxtyping.array_types._FixedDim): - return str(dim.size) + res = str(dim.size) + if dim.broadcastable: + res = "#" + res + return res elif isinstance(dim, jaxtyping.array_types._SymbolicDim): expr = code_deparse(dim.expr).text.strip().split("return ")[1] return f"({expr})" elif "jaxtyping" not in str(dim.__class__): # Probably the case that we have an ellipsis return "..." else: - return str(dim.name) + res = str(dim.name) + if dim.broadcastable: + res = "#" + res + return res # Function to format type hints @@ -152,9 +158,15 @@ def _process(annotation, config): elif hasattr(annotation, "__name__"): res = _convert_internal_and_external_class_to_strings(annotation) + elif str(annotation).startswith("typing.Callable"): + if len(annotation.__args__) == 2: + res = f"Callable[{_process(annotation.__args__[0], config)} -> {_process(annotation.__args__[1], config)}]" + else: + res = "Callable" + # Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*" # Also, convert any Optional[*A*] into "*A*, optional" - elif "typing.Union" in str(annotation): + elif str(annotation).startswith("typing.Union"): is_optional_str = "" args = list(annotation.__args__) # Hack: Optional[*A*] are represented internally as Union[*A*, Nonetype] @@ -166,13 +178,13 @@ def _process(annotation, config): res = " or ".join(processed_args) + is_optional_str # Convert any Tuple[*A*, *B*] into "(*A*, *B*)" - elif "typing.Tuple" in str(annotation): + elif str(annotation).startswith("typing.Tuple"): args = list(annotation.__args__) res = "(" + ", ".join(_process(arg, config) for arg in args) + ")" # Callable typing annotation - elif "typing." in str(annotation): - return str(annotation) + elif str(annotation).startswith("typing."): + return str(annotation)[7:] # Special cases for forward references. # This is brittle, as it only contains case for a select few forward refs diff --git a/docs/source/data_sparse_operators.rst b/docs/source/data_sparse_operators.rst index d580450f..34c5b6e1 100644 --- a/docs/source/data_sparse_operators.rst +++ b/docs/source/data_sparse_operators.rst @@ -36,6 +36,12 @@ Data-Sparse LinearOperators .. autoclass:: linear_operator.operators.IdentityLinearOperator :members: +:hidden:`KernelLinearOperator` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: linear_operator.operators.KernelLinearOperator + :members: + :hidden:`RootLinearOperator` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/linear_operator/operators/__init__.py b/linear_operator/operators/__init__.py index 79e92108..ec95aa65 100644 --- a/linear_operator/operators/__init__.py +++ b/linear_operator/operators/__init__.py @@ -14,6 +14,7 @@ from .identity_linear_operator import IdentityLinearOperator from .interpolated_linear_operator import InterpolatedLinearOperator from .keops_linear_operator import KeOpsLinearOperator +from .kernel_linear_operator import KernelLinearOperator from .kronecker_product_added_diag_linear_operator import KroneckerProductAddedDiagLinearOperator from .kronecker_product_linear_operator import ( KroneckerProductDiagLinearOperator, @@ -53,6 +54,7 @@ "IdentityLinearOperator", "InterpolatedLinearOperator", "KeOpsLinearOperator", + "KernelLinearOperator", "KroneckerProductLinearOperator", "KroneckerProductAddedDiagLinearOperator", "KroneckerProductDiagLinearOperator", diff --git a/linear_operator/operators/_linear_operator.py b/linear_operator/operators/_linear_operator.py index e558ae9e..f4c01fd2 100644 --- a/linear_operator/operators/_linear_operator.py +++ b/linear_operator/operators/_linear_operator.py @@ -3,10 +3,12 @@ from __future__ import annotations import functools +import itertools import math import numbers import warnings from abc import abstractmethod +from collections import OrderedDict from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -150,7 +152,14 @@ def __init__(self, *args, **kwargs): raise ValueError(err) self._args = args - self._kwargs = kwargs + self._differentiable_kwargs = OrderedDict() + self._nondifferentiable_kwargs = dict() + for name, val in sorted(kwargs.items()): + # Sorting is necessary so that the flattening in the representation tree is deterministic + if torch.is_tensor(val) or isinstance(val, LinearOperator): + self._differentiable_kwargs[name] = val + else: + self._nondifferentiable_kwargs[name] = val #### # The following methods need to be defined by the LinearOperator @@ -350,17 +359,24 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O """ from collections import deque - args = tuple(self.representation()) - args_with_grads = tuple(arg for arg in args if arg.requires_grad) + # Construct a detached version of each argument in the linear operator + args = [] + for arg in self.representation(): + # All arguments here are guaranteed to be tensors + if arg.dtype.is_floating_point and arg.requires_grad: + args.append(arg.detach().requires_grad_(True)) + else: + args.append(arg.detach()) - # Easy case: if we don't require any gradients, then just return! - if not len(args_with_grads): - return tuple(None for _ in args) + # If no arguments require gradients, then we're done! 
+ if not any(arg.requires_grad for arg in args): + return (None,) * len(args) - # Normal case: we'll use the autograd to get us a derivative + # We'll use the autograd to get us a derivative with torch.autograd.enable_grad(): - loss = (left_vecs * self._matmul(right_vecs)).sum() - loss.requires_grad_(True) + lin_op = self.representation_tree()(*args) + loss = (left_vecs * lin_op._matmul(right_vecs)).sum() + args_with_grads = [arg for arg in args if arg.requires_grad] actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True)) # Now make sure that the object we return has one entry for every item in args @@ -457,6 +473,10 @@ def _args(self) -> Tuple[Union[torch.Tensor, "LinearOperator", int], ...]: def _args(self, args: Tuple[Union[torch.Tensor, "LinearOperator", int], ...]) -> None: self._args_memo = args + @property + def _kwargs(self) -> Dict[str, Any]: + return {**self._differentiable_kwargs, **self._nondifferentiable_kwargs} + def _approx_diagonal(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.Tensor, "*batch N"]: """ (Optional) returns an (approximate) diagonal of the matrix @@ -1344,7 +1364,11 @@ def detach(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, " (In practice, this function removes all Tensors that make up the :obj:`~linear_operator.opeators.LinearOperator` from the computation graph.) """ - return self.clone().detach_() + detached_args = [arg.detach() if hasattr(arg, "detach") else arg for arg in self._args] + detached_kwargs = dict( + (key, val.detach() if hasattr(val, "detach") else val) for key, val in self._kwargs.items() + ) + return self.__class__(*detached_args, **detached_kwargs) def detach_(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]: """ @@ -2013,7 +2037,7 @@ def representation(self) -> Tuple[torch.Tensor, ...]: Returns the Tensors that are used to define the LinearOperator """ representation = [] - for arg in self._args: + for arg in itertools.chain(self._args, self._differentiable_kwargs.values()): if torch.is_tensor(arg): representation.append(arg) elif hasattr(arg, "representation") and callable(arg.representation): # Is it a LinearOperator? diff --git a/linear_operator/operators/keops_linear_operator.py b/linear_operator/operators/keops_linear_operator.py index 44fddec5..6990b1f2 100644 --- a/linear_operator/operators/keops_linear_operator.py +++ b/linear_operator/operators/keops_linear_operator.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + from typing import Optional, Tuple, Union import torch @@ -13,6 +15,10 @@ class KeOpsLinearOperator(LinearOperator): def __init__(self, x1, x2, covar_func, **params): + warnings.warn( + "KeOpsLinearOperator is deprecated. 
Please use KernelLinearOperator instead.",
+            DeprecationWarning,
+        )
         super().__init__(x1, x2, covar_func=covar_func, **params)
 
         self.x1 = x1.contiguous()
diff --git a/linear_operator/operators/kernel_linear_operator.py b/linear_operator/operators/kernel_linear_operator.py
new file mode 100644
index 00000000..1df031bc
--- /dev/null
+++ b/linear_operator/operators/kernel_linear_operator.py
@@ -0,0 +1,425 @@
+from collections import defaultdict
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+
+from jaxtyping import Float
+from torch import Tensor
+
+from ..utils.broadcasting import _pad_with_singletons
+from ..utils.getitem import _noop_index, IndexType
+from ..utils.memoize import cached
+from ._linear_operator import LinearOperator, to_dense
+
+
+def _x_getitem(x, batch_indices, data_index):
+    """
+    Helper function to compute x[*batch_indices, data_index, :] in an efficient way.
+    (Sometimes x needs to be expanded before calling x[*batch_indices, data_index, :]; i.e. if
+    the batch_indices broadcast. We try to prevent this expansion if possible.)
+    """
+    try:
+        x = x[(*batch_indices, data_index, _noop_index)]
+    # We're going to handle multi-batch indexing with a try-except block
+    # This way - in the default case - we can avoid doing expansions of x1, which can be costly
+    except IndexError:
+        if isinstance(batch_indices, slice):
+            x = x.expand(1, *x.shape[-2:])
+            x = x[(*batch_indices, data_index, _noop_index)]
+        elif isinstance(batch_indices, tuple):
+            if any(not isinstance(bi, slice) for bi in batch_indices):
+                raise RuntimeError(
+                    "Attempting to tensor index a non-batch matrix's batch dimensions. "
+                    f"Got batch index {batch_indices} but my shape was {x.shape}"
+                )
+            x = x.expand(*([1] * len(batch_indices)), *x.shape[-2:])
+            x = x[(*batch_indices, data_index, _noop_index)]
+    return x
+
+
+class KernelLinearOperator(LinearOperator):
+    r"""
+    Represents the kernel matrix :math:`\boldsymbol K`
+    of data :math:`\boldsymbol X_1 \in \mathbb R^{M \times D}`
+    and :math:`\boldsymbol X_2 \in \mathbb R^{N \times D}`
+    under the covariance function :math:`k_{\boldsymbol \theta}(\cdot, \cdot)`
+    (parameterized by hyperparameters :math:`\boldsymbol \theta`),
+    so that :math:`\boldsymbol K_{ij} = k_{\boldsymbol \theta}([\boldsymbol X_1]_i, [\boldsymbol X_2]_j)`.
+
+    The output of :math:`k_{\boldsymbol \theta}(\cdot,\cdot)` (`covar_func`) can either be a torch.Tensor
+    or a LinearOperator.
+
+    .. note ::
+
+        All hyperparameters have some number of batch dimensions (which broadcast with the
+        batch dimensions of x1 and x2) and some number of non-batch dimensions
+        (dimensions that would exist if we were computing a single covariance matrix).
+
+        By default, each hyperparameter is assumed to have 2 (potentially singleton) non-batch
+        dimensions. However, the number of non-batch dimensions can be specified on a
+        per-hyperparameter basis through the optional `num_nonbatch_dimensions` dictionary argument.
+
+        For example, to implement the RBF kernel
+
+        .. math::
+
+            o^2 \exp\left(
+                -\tfrac{1}{2} (\boldsymbol x_1 - \boldsymbol x_2)^\top \boldsymbol D_\ell^{-2}
+                (\boldsymbol x_1 - \boldsymbol x_2)
+            \right),
+
+        where :math:`o` is an `outputscale` parameter and :math:`D_\ell` is a diagonal `lengthscale` matrix,
+        we would expect the following shapes:
+
+        - `x1`: `(*batch_shape x N x D)`
+        - `x2`: `(*batch_shape x M x D)`
+        - `lengthscale`: `(*batch_shape x 1 x D)`
+        - `outputscale`: `(*batch_shape)`  # Note this parameter does not have non-batch dimensions
+
+        We would then supply the dictionary `num_nonbatch_dimensions = {"outputscale": 0}`.
+        (We do not need to include lengthscale in the dictionary since it has 2 non-batch dimensions.)
+
+        .. code-block:: python
+
+            # NOTE: _covar_func intentionally does not close over any parameters
+            def _covar_func(x1, x2, lengthscale, outputscale):
+                # RBF kernel function
+                # x1: ... x N x D
+                # x2: ... x M x D
+                # lengthscale: ... x 1 x D
+                # outputscale: ...
+                x1 = x1.div(lengthscale)
+                x2 = x2.div(lengthscale)
+                sq_dist = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).square().sum(dim=-1)
+                kern = sq_dist.div(-2.0).exp().mul(outputscale[..., None, None].square())
+                return kern
+
+
+            # Batches of data
+            x1 = torch.randn(3, 5, 6)
+            x2 = torch.randn(3, 4, 6)
+            # Broadcasting lengthscale and output parameters
+            lengthscale = torch.randn(2, 1, 1, 6)  # Batch shape is 2 x 1, with 2 non-batch dimensions
+            outputscale = torch.randn(2, 1)  # Batch shape is 2 x 1, no non-batch dimensions
+            kern = KernelLinearOperator(
+                x1, x2, lengthscale=lengthscale, outputscale=outputscale,
+                covar_func=_covar_func, num_nonbatch_dimensions={"outputscale": 0}
+            )
+
+            # kern is of size 2 x 3 x 5 x 4
+
+    .. warning ::
+
+        `covar_func` should not close over any parameters. Any parameters that are closed over will not have
+        gradients propagated to them.
+
+        See the example above: the lengthscale and outputscale of _covar_func are passed in as arguments,
+        rather than being externally defined variables.
+
+    :param x1: The data :math:`\boldsymbol X_1.`
+    :param x2: The data :math:`\boldsymbol X_2.`
+    :param covar_func: The covariance function :math:`k_{\boldsymbol \theta}(\cdot, \cdot)`.
+        Its arguments should be `x1`, `x2`, `**params`, and it should output the covariance matrix
+        between :math:`\boldsymbol X_1` and :math:`\boldsymbol X_2`.
+    :param num_outputs_per_input: The number of outputs per data point.
+        This parameter should be 1 for most kernels, but will be >1 for multitask kernels,
+        gradient kernels, and any other kernels that require cross-covariance terms for multiple domains.
+        If a tuple is passed, there will be a different number of outputs per input dimension
+        for the rows/cols of the kernel matrix.
+    :param params: Additional hyperparameters (:math:`\boldsymbol \theta`) or keyword arguments passed into covar_func.
+    """
+
+    def __init__(
+        self,
+        x1: Float[Tensor, "... M D"],
+        x2: Float[Tensor, "... N D"],
+        covar_func: Callable[..., Float[Union[Tensor, LinearOperator], "... M N"]],
+        num_outputs_per_input: Tuple[int, int] = (1, 1),
+        num_nonbatch_dimensions: Optional[Dict[str, int]] = None,
+        **params: Union[Tensor, Any],
+    ):
+        # Change num_nonbatch_dimensions into a default dict
+        if num_nonbatch_dimensions is None:
+            num_nonbatch_dimensions = defaultdict(lambda: 2)
+        else:
+            num_nonbatch_dimensions = defaultdict(lambda: 2, **num_nonbatch_dimensions)
+
+        # Divide params into tensors and non-tensors
+        tensor_params = dict()
+        nontensor_params = dict()
+        for name, val in params.items():
+            if torch.is_tensor(val):
+                tensor_params[name] = val
+            else:
+                nontensor_params[name] = val
+
+        # Compute param_batch_shapes
+        param_batch_shapes = dict()
+        param_nonbatch_shapes = dict()
+        for name, val in tensor_params.items():
+            if num_nonbatch_dimensions[name] == 0:
+                param_batch_shapes[name] = val.shape
+                param_nonbatch_shapes[name] = torch.Size([])
+            else:
+                nonbatch_dim = num_nonbatch_dimensions[name]
+                param_batch_shapes[name] = val.shape[:-nonbatch_dim]
+                param_nonbatch_shapes[name] = val.shape[-nonbatch_dim:]
+
+        # Ensure that x1, x2, and params can broadcast together
+        try:
+            batch_broadcast_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2], *param_batch_shapes.values())
+        except RuntimeError:
+            # Check if the issue is with x1 and x2
+            try:
+                x1_nodata_shape = torch.Size([*x1.shape[:-2], 1, x1.shape[-1]])
+                x2_nodata_shape = torch.Size([*x2.shape[:-2], 1, x2.shape[-1]])
+                torch.broadcast_shapes(x1_nodata_shape, x2_nodata_shape)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Incompatible data shapes for a kernel matrix: "
+                    f"x1.shape={tuple(x1.shape)}, x2.shape={tuple(x2.shape)}."
+                )
+
+            # If we've made it here, this means that the parameter shapes aren't compatible with x1 and x2
+            raise RuntimeError(
+                "Shape of kernel parameters "
+                f"({', '.join([str(tuple(param.shape)) for param in tensor_params.values()])}) "
+                f"is incompatible with data shapes x1.shape={tuple(x1.shape)}, x2.shape={tuple(x2.shape)}.\n"
+                "Recall that parameters passed to KernelLinearOperator should have dimensionality compatible "
+                "with the data (see documentation)."
+            )
+
+        # Create a version of each argument that is expanded to the broadcast batch shape
+        #
+        # NOTE: we must explicitly call requires_grad on each of these arguments
+        # for the automatic _bilinear_derivative to work in torch.autograd.Functions
+        if len(batch_broadcast_shape):  # Otherwise all tensors are non-batch, and we don't need to expand
+            x1 = x1.expand(*batch_broadcast_shape, *x1.shape[-2:]).contiguous().requires_grad_(x1.requires_grad)
+            x2 = x2.expand(*batch_broadcast_shape, *x2.shape[-2:]).contiguous().requires_grad_(x2.requires_grad)
+            tensor_params = {
+                name: val.expand(*batch_broadcast_shape, *param_nonbatch_shapes[name]).requires_grad_(val.requires_grad)
+                for name, val in tensor_params.items()
+            }
+        # Everything should now have the same batch shape
+
+        # Standard constructor
+        super().__init__(
+            x1,
+            x2,
+            covar_func=covar_func,
+            num_outputs_per_input=num_outputs_per_input,
+            num_nonbatch_dimensions=num_nonbatch_dimensions,
+            **tensor_params,
+            **nontensor_params,
+        )
+        self.batch_broadcast_shape = batch_broadcast_shape
+        self.x1 = x1
+        self.x2 = x2
+        self.tensor_params = tensor_params
+        self.nontensor_params = nontensor_params
+        self.covar_func = covar_func
+        self.num_outputs_per_input = num_outputs_per_input
+        self.num_nonbatch_dimensions = num_nonbatch_dimensions
+
+    @cached(name="kernel_diag")
+    def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
+        # Explicitly compute kernel diag via covar_func when it is needed rather than relying on lazy tensor ops.
+        # We will do this by shoving all of the data into a batch dimension (i.e. compute a N x ... x 1 x 1 kernel
+        # or a N x ... x num_outs_per_in x num_outs_per_in kernel)
+        # and then squeeze out the batch dimensions
+        x1 = self.x1.unsqueeze(0).transpose(0, -2)
+        x2 = self.x2.unsqueeze(0).transpose(0, -2)
+        tensor_params = {name: val.unsqueeze(0) for name, val in self.tensor_params.items()}
+        diag_mat = to_dense(self.covar_func(x1, x2, **tensor_params, **self.nontensor_params))
+        assert diag_mat.shape[-2:] == torch.Size(self.num_outputs_per_input)
+
+        # Easy case: the kernel only has one output per input (standard kernels)
+        if self.num_outputs_per_input == (1, 1):
+            return diag_mat.transpose(0, -2)[0, ..., 0]
+        # Complicated case: the kernel has multiple outputs per input (e.g. multitask kernels)
+        else:
+            # First: reshape the matrix to be ... x N x num_outputs_per_input x num_outputs_per_input
+            diag_mat = diag_mat.permute(*range(1, diag_mat.dim() - 2), 0, -2, -1)
+            # Next: get the diagonal vector, so that we have ... x N x num_outputs_per_input
+            unflattened_diag = diag_mat.diagonal(dim1=-1, dim2=-2)
+            # Finally: flatten the diagonal vector, so that we have ... x (N * num_outputs_per_input)
+            return unflattened_diag.reshape(*unflattened_diag.shape[:-2], -1)
+
+    @property
+    @cached(name="covar_mat")
+    def covar_mat(self: Float[LinearOperator, "... M N"]) -> Float[Union[Tensor, LinearOperator], "... M N"]:
+        return self.covar_func(self.x1, self.x2, **self.tensor_params, **self.nontensor_params)
+
+    def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
+        # Similar to _diagonal, we will do this by shoving all of the data into a batch dimension
+        # (i.e. compute a N x ... x 1 x 1 kernel or a N x ... x num_outs_per_in x num_outs_per_in kernel)
+        # and then squeeze out the batch dimensions
+        num_outs_per_in_rows, num_outs_per_in_cols = self.num_outputs_per_input
+        x1_ = self.x1[(*batch_indices, row_index.div(num_outs_per_in_rows, rounding_mode="floor"))].unsqueeze(
+            -2
+        )  # x1 will have shape ... x 1 x D
+        x2_ = self.x2[(*batch_indices, col_index.div(num_outs_per_in_cols, rounding_mode="floor"))].unsqueeze(
+            -2
+        )  # x2 will have shape ... x 1 x D
+        tensor_params_ = {name: val[batch_indices] for name, val in self.tensor_params.items()}  # will have shape ...
+        indices_mat = to_dense(self.covar_func(x1_, x2_, **tensor_params_, **self.nontensor_params))
+        assert indices_mat.shape[-2:] == torch.Size(self.num_outputs_per_input)
+        # Easy case: the kernel only has one output per input (standard kernels)
+        if self.num_outputs_per_input == (1, 1):
+            return indices_mat[..., 0, 0]
+        # Complicated case: the kernel has multiple outputs per input (e.g. multitask kernels)
+        else:
+            # The current shape of indices_mat is ... x num_outs_per_in_row x num_outs_per_in_col
+            # And we want the final shape to be ...
+            # Therefore, figure out which of the outputs we want to keep
+            row_output_index = row_index % num_outs_per_in_rows
+            col_output_index = col_index % num_outs_per_in_cols
+            # Now we select those specific outputs
+            # We need index tensors to select the appropriate elements from the batch dimensions
+            # of indices_mat
+            batch_indices = [
+                _pad_with_singletons(
+                    torch.arange(size, device=indices_mat.device),
+                    num_singletons_before=i,
+                    num_singletons_after=(indices_mat.dim() - 3 - i),
+                )
+                for i, size in enumerate(indices_mat.shape[:-2])
+            ]
+            return indices_mat[(*batch_indices, row_output_index, col_output_index)]
+
+    def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator:
+        # If we have multiple outputs per input, then the indices won't directly
+        # correspond to the entries of row/col. We'll have to do a little pre-processing
+        num_outs_per_in_rows, num_outs_per_in_cols = self.num_outputs_per_input
+        if num_outs_per_in_rows != 1 or num_outs_per_in_cols != 1:
+            if not isinstance(row_index, slice) or not isinstance(col_index, slice):
+                # It's too complicated to deal with tensor indices in this case - we'll use the super method
+                try:
+                    return self.covar_mat._getitem(row_index, col_index, *batch_indices)
+                except Exception:
+                    raise TypeError(
+                        f"{self.__class__.__name__} does not accept non-slice indices. "
+                        f"Got {','.join(type(t).__name__ for t in [*batch_indices, row_index, col_index])}"
+                    )
+
+            # Now we know that x1 and x2 are slices
+            # Let's make sure that the slice dimensions perfectly correspond with the number of
+            # outputs per input that we have
+            *batch_shape, num_rows, num_cols = self._size()
+            row_start, row_end, row_step = (
+                row_index.start if row_index.start is not None else 0,
+                row_index.stop if row_index.stop is not None else num_rows,
+                row_index.step if row_index.step is not None else 1,
+            )
+            col_start, col_end, col_step = (
+                col_index.start if col_index.start is not None else 0,
+                col_index.stop if col_index.stop is not None else num_cols,
+                col_index.step if col_index.step is not None else 1,
+            )
+            if row_step != 1 or col_step != 1:
+                # It's too complicated to deal with slices with steps in this case - we'll try to evaluate the kernel
+                # and use the super method
+                try:
+                    return self.covar_mat._getitem(row_index, col_index, *batch_indices)
+                except Exception:
+                    raise TypeError(f"{self.covar_mat.__class__.__name__} does not accept slices with steps.")
+            if (
+                (row_start % num_outs_per_in_rows)
+                or (col_start % num_outs_per_in_cols)
+                or (row_end % num_outs_per_in_rows)
+                or (col_end % num_outs_per_in_cols)
+            ):
+                # It's too complicated to deal with misaligned slices in this case - we'll try to evaluate the kernel
+                # and use the super method
+                try:
+                    return self.covar_mat._getitem(row_index, col_index, *batch_indices)
+                except Exception:
+                    raise TypeError(
+                        f"{self.covar_mat.__class__.__name__} received an invalid slice. "
+                        "Since the covariance function produces multiple outputs per input, the slice "
+                        "should perfectly correspond with the number of outputs per input."
+ ) + + # Otherwise - let's divide the slices by the number of outputs per input + row_index = slice(row_start // num_outs_per_in_rows, row_end // num_outs_per_in_rows, None) + col_index = slice(col_start // num_outs_per_in_cols, col_end // num_outs_per_in_cols, None) + + # Get the indices of x1 and x2 that matter for the kernel + # Call x1[*batch_indices, row_index, :] and x2[*batch_indices, col_index, :] + x1 = _x_getitem(self.x1, batch_indices, row_index) + x2 = _x_getitem(self.x2, batch_indices, col_index) + + # Call params[*batch_indices, :, :] + tensor_params = { + name: val[(*batch_indices, *([_noop_index] * self.num_nonbatch_dimensions[name]))] + for name, val in self.tensor_params.items() + } + + # Now construct a kernel with those indices + return self.__class__( + x1, + x2, + covar_func=self.covar_func, + num_outputs_per_input=self.num_outputs_per_input, + num_nonbatch_dimensions=self.num_nonbatch_dimensions, + **tensor_params, + **self.nontensor_params, + ) + + def _matmul( + self: Float[LinearOperator, "*batch M N"], + rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], + ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + return self.covar_mat @ rhs.contiguous() + + def _permute_batch(self, *dims: int) -> LinearOperator: + x1 = self.x1.permute(*dims, -2, -1) + x2 = self.x2.permute(*dims, -2, -1) + tensor_params = { + name: val.permute(*dims, *range(-self.num_nonbatch_dimensions[name], 0)) + for name, val in self.tensor_params.items() + } + return self.__class__( + x1, + x2, + covar_func=self.covar_func, + num_outputs_per_input=self.num_outputs_per_input, + num_nonbatch_dimensions=self.num_nonbatch_dimensions, + **tensor_params, + **self.nontensor_params, + ) + + def _size(self) -> torch.Size: + num_outs_per_in_rows, num_outs_per_in_cols = self.num_outputs_per_input + return torch.Size( + [ + *self.batch_broadcast_shape, + self.x1.shape[-2] * num_outs_per_in_rows, + self.x2.shape[-2] * num_outs_per_in_cols, + ] + ) + + def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + return self.__class__( + self.x2, + self.x1, + covar_func=self.covar_func, + num_outputs_per_input=self.num_outputs_per_input, + num_nonbatch_dimensions=self.num_nonbatch_dimensions, + **self.tensor_params, + **self.nontensor_params, + ) + + def _unsqueeze_batch(self, dim: int) -> LinearOperator: + x1 = self.x1.unsqueeze(dim) + x2 = self.x2.unsqueeze(dim) + tensor_params = {name: val.unsqueeze(dim) for name, val in self.tensor_params.items()} + return self.__class__( + x1, + x2, + covar_func=self.covar_func, + num_outputs_per_input=self.num_outputs_per_input, + num_nonbatch_dimensions=self.num_nonbatch_dimensions, + **tensor_params, + **self.nontensor_params, + ) diff --git a/linear_operator/operators/kronecker_product_linear_operator.py b/linear_operator/operators/kronecker_product_linear_operator.py index 3c2617de..bcd65321 100644 --- a/linear_operator/operators/kronecker_product_linear_operator.py +++ b/linear_operator/operators/kronecker_product_linear_operator.py @@ -62,23 +62,36 @@ def _t_matmul(linear_ops, kp_shape, rhs): class KroneckerProductLinearOperator(LinearOperator): r""" - Returns the Kronecker product of the given lazy tensors + Given linearOperators :math:`\boldsymbol K_1, \ldots, \boldsymbol K_P`, + this LinearOperator represents the Kronecker product :math:`\boldsymbol K_1 \otimes \ldots \otimes \boldsymbol K_P`. 
- Args: - :`linear_ops`: List of lazy tensors + :param linear_ops: :math:`\boldsymbol K_1, \ldots, \boldsymbol K_P`: the LinearOperators in the Kronecker product. """ - def __init__(self, *linear_ops): + def __init__(self, *linear_ops: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"]]): try: linear_ops = tuple(to_linear_operator(linear_op) for linear_op in linear_ops) except TypeError: raise RuntimeError("KroneckerProductLinearOperator is intended to wrap lazy tensors.") - for prev_linear_op, curr_linear_op in zip(linear_ops[:-1], linear_ops[1:]): - if prev_linear_op.batch_shape != curr_linear_op.batch_shape: - raise RuntimeError( - "KroneckerProductLinearOperator expects lazy tensors with the " - "same batch shapes. Got {}.".format([lv.batch_shape for lv in linear_ops]) - ) + + # Make batch shapes the same for all operators + try: + batch_broadcast_shape = torch.broadcast_shapes(*(linear_op.batch_shape for linear_op in linear_ops)) + except RuntimeError: + raise RuntimeError( + "Batch shapes of LinearOperators " + f"({', '.join([str(tuple(linear_op.shape)) for linear_op in linear_ops])}) " + "are incompatible for a Kronecker product." + ) + + if len(batch_broadcast_shape): # Otherwise all linear_ops are non-batch, and we don't need to expand + # NOTE: we must explicitly call requires_grad on each of these arguments + # for the automatic _bilinear_derivative to work in torch.autograd.Functions + linear_ops = [ + linear_op._expand_batch(batch_broadcast_shape).requires_grad_(linear_op.requires_grad) + for linear_op in linear_ops + ] + super().__init__(*linear_ops) self.linear_ops = linear_ops @@ -102,10 +115,6 @@ def add_diagonal( self: Float[LinearOperator, "*batch N N"], diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]], ) -> Float[LinearOperator, "*batch N N"]: - r""" - Adds a diagonal to a KroneckerProductLinearOperator - """ - from .kronecker_product_added_diag_linear_operator import KroneckerProductAddedDiagLinearOperator if not self.is_square: diff --git a/linear_operator/operators/linear_operator_representation_tree.py b/linear_operator/operators/linear_operator_representation_tree.py index 940b9425..838d0d63 100644 --- a/linear_operator/operators/linear_operator_representation_tree.py +++ b/linear_operator/operators/linear_operator_representation_tree.py @@ -1,14 +1,16 @@ #!/usr/bin/env python3 +import itertools class LinearOperatorRepresentationTree(object): def __init__(self, linear_op): self._cls = linear_op.__class__ - self._kwargs = linear_op._kwargs + self._differentiable_kwarg_names = linear_op._differentiable_kwargs.keys() + self._nondifferentiable_kwargs = linear_op._nondifferentiable_kwargs counter = 0 self.children = [] - for arg in linear_op._args: + for arg in itertools.chain(linear_op._args, linear_op._differentiable_kwargs.values()): if hasattr(arg, "representation") and callable(arg.representation): # Is it a lazy tensor? 
representation_size = len(arg.representation()) self.children.append((slice(counter, counter + representation_size, None), arg.representation_tree())) @@ -27,4 +29,14 @@ def __call__(self, *flattened_representation): sub_representation = flattened_representation[index] unflattened_representation.append(subtree(*sub_representation)) - return self._cls(*unflattened_representation, **self._kwargs) + if len(self._differentiable_kwarg_names): + args = unflattened_representation[: -len(self._differentiable_kwarg_names)] + differentiable_kwargs = dict( + zip( + self._differentiable_kwarg_names, + unflattened_representation[-len(self._differentiable_kwarg_names) :], + ) + ) + return self._cls(*args, **differentiable_kwargs, **self._nondifferentiable_kwargs) + else: + return self._cls(*unflattened_representation, **self._nondifferentiable_kwargs) diff --git a/linear_operator/test/linear_operator_test_case.py b/linear_operator/test/linear_operator_test_case.py index 8f2b79ae..49e6ce74 100644 --- a/linear_operator/test/linear_operator_test_case.py +++ b/linear_operator/test/linear_operator_test_case.py @@ -909,7 +909,7 @@ def test_is_close(self): def test_logdet(self): tolerances = self.tolerances["logdet"] - linear_op = self.create_linear_op() + linear_op = self.create_linear_op().detach() linear_op_copy = linear_op.detach().clone() linear_op.requires_grad_(True) linear_op_copy.requires_grad_(True) diff --git a/test/operators/test_kernel_linear_operator.py b/test/operators/test_kernel_linear_operator.py new file mode 100644 index 00000000..ccaea889 --- /dev/null +++ b/test/operators/test_kernel_linear_operator.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 + +import unittest + +import torch + +from linear_operator.operators import ( + KernelLinearOperator, + KroneckerProductLinearOperator, + MatmulLinearOperator, + RootLinearOperator, +) +from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase, RectangularLinearOperatorTestCase + + +def _covar_func(x1, x2, lengthscale, outputscale): + # RBF kernel function + # x1: ... x N x D + # x2: ... x M x D + # lengthscale: ... x 1 x D + # outputscale: ... + lengthscale = lengthscale.mean(dim=-3) # Remove extraneous dimension added for testing + x1 = x1.div(lengthscale) + x2 = x2.div(lengthscale) + sq_dist = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).square().sum(dim=-1) + kern = sq_dist.div(-2.0).exp().mul(outputscale[..., None, None].square()) + return kern + + +def _nystrom_covar_func(x1, x2, lengthscale, outputscale, inducing_points): + # RBF kernel function w/ Nystrom approximation + # x1: ... x N x D + # x2: ... x M x D + # lengthscale: ... x 1 x D + # outputscale: ... + ones = torch.ones_like(outputscale) + K_zz_chol = _covar_func(inducing_points, inducing_points, lengthscale, ones) + K_zx1 = _covar_func(inducing_points, x1, lengthscale, ones) + K_zx2 = _covar_func(inducing_points, x2, lengthscale, ones) + kern = MatmulLinearOperator( + outputscale[..., None, None] * torch.linalg.solve_triangular(K_zz_chol, K_zx1, upper=False).mT, + outputscale[..., None, None] * torch.linalg.solve_triangular(K_zz_chol, K_zx2, upper=False), + ) + return kern + + +def _multitask_covar_func(x1, x2, lengthscale, outputscale, lmc_coeffs): + # RBF kernel function w/ Nystrom approximation + # x1: ... x N x D + # x2: ... x M x D + # lengthscale: ... x 1 x D + # outputscale: ... 
+ K_xx = _covar_func(x1, x2, lengthscale=lengthscale, outputscale=outputscale) + return KroneckerProductLinearOperator(K_xx, RootLinearOperator(lmc_coeffs)) + + +class TestKernelLinearOperatorRectangular(RectangularLinearOperatorTestCase, unittest.TestCase): + seed = 0 + + def create_linear_op(self): + x1 = torch.randn(3, 1, 5, 6) + x2 = torch.randn(2, 4, 6) + lengthscale = torch.nn.Parameter(torch.ones(4, 1, 6)) + # Adding an extraneous -3 dimension to test functionality + outputscale = torch.nn.Parameter(torch.ones(3, 2)) + return KernelLinearOperator( + x1, + x2, + lengthscale=lengthscale, + outputscale=outputscale, + covar_func=_covar_func, + num_nonbatch_dimensions={"lengthscale": 3, "outputscale": 0}, + ) + + def evaluate_linear_op(self, linop): + return _covar_func(linop.x1, linop.x2, **linop.tensor_params) + + +class TestKernelLinearOperator(LinearOperatorTestCase, unittest.TestCase): + seed = 0 + + def create_linear_op(self): + x = torch.randn(3, 5, 6) + lengthscale = torch.nn.Parameter(torch.ones(3, 4, 1, 6)) + # Adding an extraneous -3 dimension to test functionality + outputscale = torch.nn.Parameter(torch.ones(2, 1)) + return KernelLinearOperator( + x, + x, + lengthscale=lengthscale, + outputscale=outputscale, + covar_func=_covar_func, + num_nonbatch_dimensions={"lengthscale": 3, "outputscale": 0}, + ) + + def evaluate_linear_op(self, linop): + return _covar_func(linop.x1, linop.x2, **linop.tensor_params) + + +class TestKernelLinearOperatorRectangularLinOpReturn(TestKernelLinearOperatorRectangular, unittest.TestCase): + seed = 0 + + def create_linear_op(self): + x1 = torch.randn(3, 4, 6) + x2 = torch.randn(3, 5, 6) + inducing_points = torch.randn(3, 6) + lengthscale = torch.nn.Parameter(torch.ones(3, 4, 1, 6)) + # Adding an extraneous -3 dimension to test functionality + outputscale = torch.nn.Parameter(torch.ones(2, 1)) + return KernelLinearOperator( + x1, + x2, + lengthscale=lengthscale, + outputscale=outputscale, + inducing_points=inducing_points, + covar_func=_nystrom_covar_func, + num_nonbatch_dimensions={"lengthscale": 3, "outputscale": 0}, + ) + + def evaluate_linear_op(self, linop): + return _nystrom_covar_func(linop.x1, linop.x2, **linop.tensor_params).to_dense() + + +class TestKernelLinearOperatorLinOpReturn(TestKernelLinearOperator, unittest.TestCase): + seed = 0 + + def create_linear_op(self): + x = torch.randn(3, 4, 6) + inducing_points = torch.randn(20, 6) # Overparameterized nystrom approx for invertibility + lengthscale = torch.nn.Parameter(torch.ones(3, 4, 1, 6)) + # Adding an extraneous -3 dimension to test functionality + outputscale = torch.nn.Parameter(torch.ones(2, 1)) + return KernelLinearOperator( + x, + x, + lengthscale=lengthscale, + outputscale=outputscale, + inducing_points=inducing_points, + covar_func=_nystrom_covar_func, + num_nonbatch_dimensions={"lengthscale": 3, "outputscale": 0}, + ) + + def evaluate_linear_op(self, linop): + return _nystrom_covar_func(linop.x1, linop.x2, **linop.tensor_params).to_dense() + + +class TestKernelLinearOperatorMultiOutput(TestKernelLinearOperator, unittest.TestCase): + seed = 0 + + def create_linear_op(self): + x = torch.randn(3, 4, 6) + lengthscale = torch.nn.Parameter(torch.ones(3, 4, 1, 6)) + # Adding an extraneous -3 dimension to test functionality + outputscale = torch.nn.Parameter(torch.ones(2, 1)) + lmc_coeffs = torch.nn.Parameter(torch.tensor([[1.0, 0.5], [0.5, 1.0]])) + return KernelLinearOperator( + x, + x, + lengthscale=lengthscale, + outputscale=outputscale, + lmc_coeffs=lmc_coeffs, + 
covar_func=_multitask_covar_func, + num_outputs_per_input=(2, 2), + num_nonbatch_dimensions={"lengthscale": 3, "outputscale": 0}, + ) + + def evaluate_linear_op(self, linop): + return _multitask_covar_func(linop.x1, linop.x2, **linop.tensor_params).to_dense()
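
A minimal usage sketch of the `KernelLinearOperator` API introduced above (not part of the patch): `rbf_covar_func` and all shapes below are illustrative stand-ins modeled on the RBF example in the class docstring, not code taken from the library or its tests.

.. code-block:: python

    import torch

    from linear_operator.operators import KernelLinearOperator


    def rbf_covar_func(x1, x2, lengthscale, outputscale):
        # Hyperparameters arrive as keyword arguments (no closure), per the docstring warning
        x1 = x1.div(lengthscale)
        x2 = x2.div(lengthscale)
        sq_dist = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).square().sum(dim=-1)
        return sq_dist.div(-2.0).exp().mul(outputscale[..., None, None].square())


    x1 = torch.randn(3, 5, 6)  # batch_shape = (3,), N = 5, D = 6
    x2 = torch.randn(3, 4, 6)  # batch_shape = (3,), M = 4, D = 6
    lengthscale = torch.nn.Parameter(torch.ones(3, 1, 6))  # 2 non-batch dimensions (the default)
    outputscale = torch.nn.Parameter(torch.ones(3))  # 0 non-batch dimensions

    kern = KernelLinearOperator(
        x1,
        x2,
        lengthscale=lengthscale,
        outputscale=outputscale,
        covar_func=rbf_covar_func,
        num_nonbatch_dimensions={"outputscale": 0},
    )
    assert kern.shape == torch.Size([3, 5, 4])

    # Tensor keyword arguments are now part of representation(), so gradients should
    # propagate back to lengthscale and outputscale through operator arithmetic
    (kern @ torch.randn(3, 4, 1)).sum().backward()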