Add KernelLinearOperator, deprecate KeOpsLinearOperator (#62)
* Add KernelLinearOperator, deprecate KeOpsLinearOperator

KeOpsLinearOperator does not correctly backpropagate gradients if the
covar_func closes over parameters.

KernelLinearOperator corrects for this, and is set up to replace
LazyEvaluatedKernelTensor in GPyTorch down the line (a usage sketch
follows this change list).

* Fix KeOpsLinearOperator deprecation

* Allow for kernels with reduced batches and multiple outputs per input

* LinearOperator kwargs can be differentiated through

Previously, only positional args were added to the LinearOperator
representation, so only positional args received gradients from
_bilinear_derivative.

This commit adds Tensor/LinearOperator kwargs to the representation as
well, so kwarg Tensors/LinearOperators also receive gradients.

* Hyperparameters for KernelLinearOperator must be kwargs

* LO._bilinear_derivative only computes derivatives for args that require gradients

* Expand upon closure variables warning for KernelLinearOperator

* LO._bilinear_derivative exits early if no parameters require gradients

* Refactor KernelLinearOperator._getitem

* Allow for optional number of nonbatch parameter dimensions

* Fix LO._bilinear_derivative

* Update linear_operator/operators/_linear_operator.py

Co-authored-by: Max Balandat <[email protected]>

* Update linear_operator/operators/kernel_linear_operator.py

Co-authored-by: Max Balandat <[email protected]>

* Update linear_operator/operators/linear_operator_representation_tree.py

Co-authored-by: Max Balandat <[email protected]>

* Update linear_operator/operators/kernel_linear_operator.py

Co-authored-by: Max Balandat <[email protected]>

* Update linear_operator/operators/kernel_linear_operator.py

Co-authored-by: Max Balandat <[email protected]>

* Update linear_operator/operators/kernel_linear_operator.py

Co-authored-by: Max Balandat <[email protected]>

* Update linear_operator/operators/kernel_linear_operator.py

Co-authored-by: Max Balandat <[email protected]>

* Update linear_operator/operators/kernel_linear_operator.py

Co-authored-by: Max Balandat <[email protected]>

* Update linear_operator/operators/kernel_linear_operator.py

Co-authored-by: Max Balandat <[email protected]>

* Update linear_operator/operators/kernel_linear_operator.py

Co-authored-by: Max Balandat <[email protected]>

* Update linear_operator/operators/kernel_linear_operator.py

Co-authored-by: Max Balandat <[email protected]>

* Fix errors, address comments

* KroneckerProductLinearOperator broadcasts

* Test cases and fixes for multitask KernelLinearOperator

---------

Co-authored-by: Max Balandat <[email protected]>
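
To make the new pattern concrete, here is a minimal usage sketch (not part of the commit; the toy RBF covar_func, tensor shapes, and hyperparameter name are illustrative assumptions). The point is that hyperparameters are handed to KernelLinearOperator as keyword arguments rather than captured in a closure, so they become part of the operator's representation and receive gradients:

    import torch
    from linear_operator.operators import KernelLinearOperator

    def covar_func(x1, x2, lengthscale):
        # Toy RBF kernel. `lengthscale` is supplied by KernelLinearOperator as a
        # kwarg at evaluation time -- it is not a closed-over variable.
        x1_ = x1.div(lengthscale)
        x2_ = x2.div(lengthscale)
        sq_dist = x1_.unsqueeze(-2).sub(x2_.unsqueeze(-3)).pow(2).sum(dim=-1)
        return sq_dist.div(-2.0).exp()

    x1 = torch.randn(100, 3)
    x2 = torch.randn(50, 3)
    lengthscale = torch.ones(1, 3, requires_grad=True)  # hyperparameter passed as a kwarg Tensor

    covar = KernelLinearOperator(x1, x2, lengthscale=lengthscale, covar_func=covar_func)
    loss = (covar @ torch.randn(50, 4)).sum()
    loss.backward()
    assert lengthscale.grad is not None  # gradients reach the kwarg hyperparameter

Under the deprecated KeOpsLinearOperator, a covar_func that closed over lengthscale would not propagate these gradients, which is the bug described above.
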
gpleiss and Balandat committed Jun 2, 2023
1 parent f020146 commit 7affaf3
Showing 10 changed files with 699 additions and 35 deletions.
24 changes: 18 additions & 6 deletions docs/source/conf.py
@@ -116,14 +116,20 @@ def _dim_to_str(dim):
if isinstance(dim, jaxtyping.array_types._NamedVariadicDim):
return "..."
elif isinstance(dim, jaxtyping.array_types._FixedDim):
return str(dim.size)
res = str(dim.size)
if dim.broadcastable:
res = "#" + res
return res
elif isinstance(dim, jaxtyping.array_types._SymbolicDim):
expr = code_deparse(dim.expr).text.strip().split("return ")[1]
return f"({expr})"
elif "jaxtyping" not in str(dim.__class__): # Probably the case that we have an ellipsis
return "..."
else:
return str(dim.name)
res = str(dim.name)
if dim.broadcastable:
res = "#" + res
return res


# Function to format type hints
@@ -152,9 +158,15 @@ def _process(annotation, config):
elif hasattr(annotation, "__name__"):
res = _convert_internal_and_external_class_to_strings(annotation)

elif str(annotation).startswith("typing.Callable"):
if len(annotation.__args__) == 2:
res = f"Callable[{_process(annotation.__args__[0], config)} -> {_process(annotation.__args__[1], config)}]"
else:
res = "Callable"

# Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*"
# Also, convert any Optional[*A*] into "*A*, optional"
elif "typing.Union" in str(annotation):
elif str(annotation).startswith("typing.Union"):
is_optional_str = ""
args = list(annotation.__args__)
# Hack: Optional[*A*] are represented internally as Union[*A*, Nonetype]
@@ -166,13 +178,13 @@
res = " or ".join(processed_args) + is_optional_str

# Convert any Tuple[*A*, *B*] into "(*A*, *B*)"
elif "typing.Tuple" in str(annotation):
elif str(annotation).startswith("typing.Tuple"):
args = list(annotation.__args__)
res = "(" + ", ".join(_process(arg, config) for arg in args) + ")"

# Callable typing annotation
elif "typing." in str(annotation):
return str(annotation)
elif str(annotation).startswith("typing."):
return str(annotation)[7:]

# Special cases for forward references.
# This is brittle, as it only contains case for a select few forward refs
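
For orientation, the conf.py hunks above tighten how type hints are rendered in the docs: broadcastable jaxtyping dimensions get a "#" prefix, Callable annotations render as "Callable[A -> B]", the Union/Tuple checks use startswith, and remaining typing.* annotations have the "typing." prefix stripped. The standalone sketch below mirrors just the typing.* branches on plain annotations; it is an illustration, not the actual _process helper, which also handles jaxtyping arrays, internal classes, and forward references:

    import typing

    def sketch_process(annotation) -> str:
        # Rough mirror of the typing.* branches in docs/source/conf.py::_process.
        s = str(annotation)
        if s.startswith("typing.Callable") and len(annotation.__args__) == 2:
            arg, ret = annotation.__args__
            return f"Callable[{sketch_process(arg)} -> {sketch_process(ret)}]"
        elif s.startswith("typing.Union"):
            args = list(annotation.__args__)
            is_optional = type(None) in args  # Optional[A] is Union[A, NoneType] internally
            if is_optional:
                args.remove(type(None))
            return " or ".join(sketch_process(a) for a in args) + (", optional" if is_optional else "")
        elif s.startswith("typing.Tuple"):
            return "(" + ", ".join(sketch_process(a) for a in annotation.__args__) + ")"
        elif s.startswith("typing."):
            return s[7:]  # drop the leading "typing."
        return getattr(annotation, "__name__", s)

    print(sketch_process(typing.Union[int, str]))       # int or str
    print(sketch_process(typing.Tuple[int, str]))       # (int, str)
    print(sketch_process(typing.Callable[[int], str]))  # Callable[int -> str]
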
6 changes: 6 additions & 0 deletions docs/source/data_sparse_operators.rst
@@ -36,6 +36,12 @@ Data-Sparse LinearOperators
.. autoclass:: linear_operator.operators.IdentityLinearOperator
:members:

:hidden:`KernelLinearOperator`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: linear_operator.operators.KernelLinearOperator
:members:

:hidden:`RootLinearOperator`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions linear_operator/operators/__init__.py
@@ -14,6 +14,7 @@
from .identity_linear_operator import IdentityLinearOperator
from .interpolated_linear_operator import InterpolatedLinearOperator
from .keops_linear_operator import KeOpsLinearOperator
from .kernel_linear_operator import KernelLinearOperator
from .kronecker_product_added_diag_linear_operator import KroneckerProductAddedDiagLinearOperator
from .kronecker_product_linear_operator import (
KroneckerProductDiagLinearOperator,
@@ -53,6 +54,7 @@
"IdentityLinearOperator",
"InterpolatedLinearOperator",
"KeOpsLinearOperator",
"KernelLinearOperator",
"KroneckerProductLinearOperator",
"KroneckerProductAddedDiagLinearOperator",
"KroneckerProductDiagLinearOperator",
46 changes: 35 additions & 11 deletions linear_operator/operators/_linear_operator.py
@@ -3,10 +3,12 @@
from __future__ import annotations

import functools
import itertools
import math
import numbers
import warnings
from abc import abstractmethod
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -150,7 +152,14 @@ def __init__(self, *args, **kwargs):
raise ValueError(err)

self._args = args
self._kwargs = kwargs
self._differentiable_kwargs = OrderedDict()
self._nondifferentiable_kwargs = dict()
for name, val in sorted(kwargs.items()):
# Sorting is necessary so that the flattening in the representation tree is deterministic
if torch.is_tensor(val) or isinstance(val, LinearOperator):
self._differentiable_kwargs[name] = val
else:
self._nondifferentiable_kwargs[name] = val

####
# The following methods need to be defined by the LinearOperator
@@ -350,17 +359,24 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O
"""
from collections import deque

args = tuple(self.representation())
args_with_grads = tuple(arg for arg in args if arg.requires_grad)
# Construct a detached version of each argument in the linear operator
args = []
for arg in self.representation():
# All arguments here are guaranteed to be tensors
if arg.dtype.is_floating_point and arg.requires_grad:
args.append(arg.detach().requires_grad_(True))
else:
args.append(arg.detach())

# Easy case: if we don't require any gradients, then just return!
if not len(args_with_grads):
return tuple(None for _ in args)
# If no arguments require gradients, then we're done!
if not any(arg.requires_grad for arg in args):
return (None,) * len(args)

# Normal case: we'll use the autograd to get us a derivative
# We'll use the autograd to get us a derivative
with torch.autograd.enable_grad():
loss = (left_vecs * self._matmul(right_vecs)).sum()
loss.requires_grad_(True)
lin_op = self.representation_tree()(*args)
loss = (left_vecs * lin_op._matmul(right_vecs)).sum()
args_with_grads = [arg for arg in args if arg.requires_grad]
actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True))

# Now make sure that the object we return has one entry for every item in args
@@ -457,6 +473,10 @@ def _args(self) -> Tuple[Union[torch.Tensor, "LinearOperator", int], ...]:
def _args(self, args: Tuple[Union[torch.Tensor, "LinearOperator", int], ...]) -> None:
self._args_memo = args

@property
def _kwargs(self) -> Dict[str, Any]:
return {**self._differentiable_kwargs, **self._nondifferentiable_kwargs}

def _approx_diagonal(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.Tensor, "*batch N"]:
"""
(Optional) returns an (approximate) diagonal of the matrix
@@ -1344,7 +1364,11 @@ def detach(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "
(In practice, this function removes all Tensors that make up the
:obj:`~linear_operator.opeators.LinearOperator` from the computation graph.)
"""
return self.clone().detach_()
detached_args = [arg.detach() if hasattr(arg, "detach") else arg for arg in self._args]
detached_kwargs = dict(
(key, val.detach() if hasattr(val, "detach") else val) for key, val in self._kwargs.items()
)
return self.__class__(*detached_args, **detached_kwargs)

def detach_(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]:
"""
@@ -2013,7 +2037,7 @@ def representation(self) -> Tuple[torch.Tensor, ...]:
Returns the Tensors that are used to define the LinearOperator
"""
representation = []
for arg in self._args:
for arg in itertools.chain(self._args, self._differentiable_kwargs.values()):
if torch.is_tensor(arg):
representation.append(arg)
elif hasattr(arg, "representation") and callable(arg.representation): # Is it a LinearOperator?
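
Continuing the sketch from the commit message (same assumed covar_func, shapes, and hyperparameter name as above), the effect of the representation/_bilinear_derivative changes can be observed directly:

    x1, x2 = torch.randn(100, 3), torch.randn(50, 3)
    lengthscale = torch.ones(1, 3, requires_grad=True)
    lin_op = KernelLinearOperator(x1, x2, lengthscale=lengthscale, covar_func=covar_func)

    # Tensor kwargs are now flattened into the representation alongside x1 and x2,
    # giving the autograd machinery something to differentiate with respect to.
    reps = lin_op.representation()

    # _bilinear_derivative returns one slot per representation tensor; entries are
    # None for tensors that do not require gradients, and it exits early if nothing does.
    left_vecs, right_vecs = torch.randn(100, 4), torch.randn(50, 4)
    grads = lin_op._bilinear_derivative(left_vecs, right_vecs)
    assert len(grads) == len(reps)
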
6 changes: 6 additions & 0 deletions linear_operator/operators/keops_linear_operator.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import warnings

from typing import Optional, Tuple, Union

import torch
@@ -13,6 +15,10 @@

class KeOpsLinearOperator(LinearOperator):
def __init__(self, x1, x2, covar_func, **params):
warnings.warn(
"KeOpsLinearOperator is deprecated. Please use KernelLinearOperator instead.",
DeprecationWarning,
)
super().__init__(x1, x2, covar_func=covar_func, **params)

self.x1 = x1.contiguous()
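
The deprecation itself is easy to observe (a sketch; the lambda covar_func is only a stand-in):

    import warnings
    import torch
    from linear_operator.operators import KeOpsLinearOperator

    x1, x2 = torch.randn(100, 3), torch.randn(50, 3)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Construction still works, but now emits the DeprecationWarning added above.
        KeOpsLinearOperator(x1, x2, covar_func=lambda a, b: a @ b.transpose(-1, -2))
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
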
(Diffs for the remaining five changed files did not load and are not shown.)