Update docs for AddedDiag, Identity, Zero
gpleiss committed Aug 16, 2022
1 parent 3e85dc8 commit 987df55
Showing 4 changed files with 132 additions and 92 deletions.
6 changes: 6 additions & 0 deletions docs/source/data_sparse_operators.rst
@@ -30,6 +30,12 @@ Data-Sparse LinearOperators
.. autoclass:: linear_operator.operators.DiagLinearOperator
:members:

:hidden:`IdentityLinearOperator`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: linear_operator.operators.IdentityLinearOperator
:members:

:hidden:`RootLinearOperator`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

5 changes: 3 additions & 2 deletions linear_operator/operators/added_diag_linear_operator.py
@@ -18,8 +18,9 @@

class AddedDiagLinearOperator(SumLinearOperator):
"""
A SumLinearOperator, but of only two linear operators, the second of which must be
a DiagLinearOperator.
A :class:`~linear_operator.operators.SumLinearOperator`, but of only two
linear operators, the second of which must be a
:class:`~linear_operator.operators.DiagLinearOperator`.
:param linear_ops: The LinearOperator, and the DiagLinearOperator to add to it.
:param preconditioner_override: A preconditioning method to be used with conjugate gradients.
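
A minimal usage sketch of the class documented above, assuming the two operators are passed positionally as the docstring suggests and that DenseLinearOperator and DiagLinearOperator are importable from linear_operator.operators:

import torch
from linear_operator.operators import AddedDiagLinearOperator, DenseLinearOperator, DiagLinearOperator

# Represent (K + D) lazily, where K is a dense operator and D is diagonal.
K = DenseLinearOperator(2.0 * torch.eye(3))
D = DiagLinearOperator(torch.tensor([0.1, 0.2, 0.3]))
A = AddedDiagLinearOperator(K, D)  # the diagonal operator must come second
print(A.shape)  # torch.Size([3, 3])
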
105 changes: 59 additions & 46 deletions linear_operator/operators/identity_linear_operator.py
@@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
from torch import Tensor
@@ -11,19 +11,27 @@
from ..utils.memoize import cached
from ._linear_operator import LinearOperator
from .diag_linear_operator import ConstantDiagLinearOperator
from .triangular_linear_operator import TriangularLinearOperator
from .zero_linear_operator import ZeroLinearOperator


class IdentityLinearOperator(ConstantDiagLinearOperator):
def __init__(self, diag_shape, batch_shape=torch.Size([]), dtype=None, device=None):
"""
Identity matrix lazy tensor. Supports arbitrary batch sizes.
Args:
:attr:`diag` (Tensor):
A `b1 x ... x bk x n` Tensor, representing a `b1 x ... x bk`-sized batch
of `n x n` identity matrices
"""
"""
Identity linear operator. Supports arbitrary batch sizes.
:param diag_shape: The size of the identity matrix (i.e. :math:`N`).
:param batch_shape: The size of the batch dimensions. It may be useful to set these dimensions for broadcasting.
:param dtype: Dtype that the LinearOperator will be operating on. (Default: :meth:`torch.get_default_dtype()`).
:param device: Device that the LinearOperator will be operating on. (Default: CPU).
"""

def __init__(
self,
diag_shape: int,
batch_shape: Optional[torch.Size] = torch.Size([]),
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
one = torch.tensor(1.0, dtype=dtype, device=device)
LinearOperator.__init__(self, diag_shape=diag_shape, batch_shape=batch_shape, dtype=dtype, device=device)
self.diag_values = one.expand(torch.Size([*batch_shape, 1]))
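
A short construction sketch matching the new signature above (parameter names come from the diff; the batch size and dtype are illustrative only):

import torch
from linear_operator.operators import IdentityLinearOperator

# A batch of two 4 x 4 identity operators in double precision (CPU by default).
eye = IdentityLinearOperator(diag_shape=4, batch_shape=torch.Size([2]), dtype=torch.float64)
print(eye.shape)  # torch.Size([2, 4, 4])
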
@@ -33,40 +41,42 @@ def __init__(self, diag_shape, batch_shape=torch.Size([]), dtype=None, device=None):
self._device = device

@property
def batch_shape(self):
"""
Returns the shape over which the tensor is batched.
"""
def batch_shape(self) -> torch.Size:
return self._batch_shape

@property
def dtype(self):
def dtype(self) -> torch.dtype:
return self._dtype

@property
def device(self):
def device(self) -> torch.device:
return self._device

def _maybe_reshape_rhs(self, rhs):
def _maybe_reshape_rhs(self, rhs: torch.Tensor) -> torch.Tensor:
if self._batch_shape != rhs.shape[:-2]:
batch_shape = torch.broadcast_shapes(rhs.shape[:-2], self._batch_shape)
return rhs.expand(*batch_shape, *rhs.shape[-2:])
else:
return rhs

@cached(name="cholesky", ignore_args=True)
def _cholesky(self, upper=False):
def _cholesky(self, upper: Optional[bool] = False) -> TriangularLinearOperator:
return self

def _cholesky_solve(self, rhs):
def _cholesky_solve(self, rhs: torch.Tensor) -> torch.Tensor:
return self._maybe_reshape_rhs(rhs)

def _expand_batch(self, batch_shape):
def _expand_batch(self, batch_shape: torch.Size) -> LinearOperator:
return IdentityLinearOperator(
diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self.dtype, device=self.device
)

def _getitem(self, row_index, col_index, *batch_indices):
def _getitem(
self,
row_index: Union[slice, torch.LongTensor],
col_index: Union[slice, torch.LongTensor],
*batch_indices: Tuple[Union[int, slice, torch.LongTensor], ...],
) -> LinearOperator:
# Special case: if both row and col are not indexed, then we are done
if _is_noop_index(row_index) and _is_noop_index(col_index):
if len(batch_indices):
Expand All @@ -80,35 +90,39 @@ def _getitem(self, row_index, col_index, *batch_indices):

return super()._getitem(row_index, col_index, *batch_indices)

def _matmul(self, rhs):
def _matmul(self, rhs: torch.Tensor) -> torch.Tensor:
return self._maybe_reshape_rhs(rhs)

def _mul_constant(self, constant):
return ConstantDiagLinearOperator(self.diag_values * constant, diag_shape=self.diag_shape)
def _mul_constant(self, other: Union[float, torch.Tensor]) -> LinearOperator:
return ConstantDiagLinearOperator(self.diag_values * other, diag_shape=self.diag_shape)

def _mul_matrix(self, other):
def _mul_matrix(self, other: Union[torch.Tensor, LinearOperator]) -> LinearOperator:
return other

def _permute_batch(self, *dims):
def _permute_batch(self, *dims: Tuple[int, ...]) -> LinearOperator:
batch_shape = self.diag_values.permute(*dims, -1).shape[:-1]
return IdentityLinearOperator(
diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self._dtype, device=self._device
)

def _prod_batch(self, dim):
def _prod_batch(self, dim: int) -> LinearOperator:
batch_shape = list(self.batch_shape)
del batch_shape[dim]
return IdentityLinearOperator(
diag_shape=self.diag_shape, batch_shape=torch.Size(batch_shape), dtype=self.dtype, device=self.device
)

def _root_decomposition(self):
def _root_decomposition(self) -> LinearOperator:
return self.sqrt()

def _root_inv_decomposition(self, initial_vectors=None):
def _root_inv_decomposition(
self,
initial_vectors: Optional[torch.Tensor] = None,
test_vectors: Optional[torch.Tensor] = None,
) -> LinearOperator:
return self.inverse().sqrt()

def _size(self):
def _size(self) -> torch.Size:
return torch.Size([*self._batch_shape, self.diag_shape, self.diag_shape])

@cached(name="svd")
@@ -118,10 +132,10 @@ def _svd(self) -> Tuple[LinearOperator, Tensor, LinearOperator]:
def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LinearOperator]]:
return self._diag, self

def _t_matmul(self, rhs):
def _t_matmul(self, rhs: torch.Tensor) -> LinearOperator:
return self._maybe_reshape_rhs(rhs)

def _transpose_nonbatch(self):
def _transpose_nonbatch(self) -> LinearOperator:
return self

def _unsqueeze_batch(self, dim: int) -> IdentityLinearOperator:
@@ -132,16 +146,18 @@ def _unsqueeze_batch(self, dim: int) -> IdentityLinearOperator:
diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self.dtype, device=self.device
)

def abs(self):
def abs(self) -> LinearOperator:
return self

def exp(self):
def exp(self) -> LinearOperator:
return self

def inverse(self):
def inverse(self) -> LinearOperator:
return self

def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
def inv_quad_logdet(
self, inv_quad_rhs: Optional[torch.Tensor] = None, logdet: bool = False, reduce_inv_quad: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO: Use proper batching for inv_quad_rhs (prepend to shape rather than append)
if inv_quad_rhs is None:
inv_quad_term = torch.empty(0, dtype=self.dtype, device=self.device)
Expand All @@ -158,12 +174,12 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)

return inv_quad_term, logdet_term

def log(self):
def log(self) -> LinearOperator:
return ZeroLinearOperator(
*self._batch_shape, self.diag_shape, self.diag_shape, dtype=self._dtype, device=self._device
)

def matmul(self, other):
def matmul(self, other: Union[torch.Tensor, LinearOperator]) -> Union[torch.Tensor, LinearOperator]:
is_vec = False
if other.dim() == 1:
is_vec = True
@@ -173,31 +189,28 @@ def matmul(self, other):
res = res.squeeze(-1)
return res

def solve(self, right_tensor, left_tensor=None):
def solve(self, right_tensor: torch.Tensor, left_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
res = self._maybe_reshape_rhs(right_tensor)
if left_tensor is not None:
res = left_tensor @ res
return res

def sqrt(self):
def sqrt(self) -> LinearOperator:
return self

def sqrt_inv_matmul(self, rhs, lhs=None):
def sqrt_inv_matmul(self, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None) -> torch.Tensor:
if lhs is None:
return self._maybe_reshape_rhs(rhs)
else:
sqrt_inv_matmul = lhs @ rhs
inv_quad = lhs.pow(2).sum(dim=-1)
return sqrt_inv_matmul, inv_quad

def type(self, dtype):
"""
This method operates similarly to :func:`torch.Tensor.type`.
"""
def type(self, dtype: torch.dtype) -> LinearOperator:
return IdentityLinearOperator(
diag_shape=self.diag_shape, batch_shape=self.batch_shape, dtype=dtype, device=self.device
)

def zero_mean_mvn_samples(self, num_samples):
def zero_mean_mvn_samples(self, num_samples: int) -> torch.Tensor:
base_samples = torch.randn(num_samples, *self.shape[:-1], dtype=self.dtype, device=self.device)
return base_samples
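
Taken together, the typed methods above preserve the identity semantics: matmul and solve return (a possibly broadcast view of) the right-hand side, and the log-determinant is zero. A hedged behavioural sketch, assuming the default constructor arguments shown in the diff:

import torch
from linear_operator.operators import IdentityLinearOperator

eye = IdentityLinearOperator(diag_shape=3)
rhs = torch.randn(3, 2)

print(torch.equal(eye.matmul(rhs), rhs))   # I @ rhs should equal rhs
print(torch.equal(eye.solve(rhs), rhs))    # solving I x = rhs should return rhs
inv_quad, logdet = eye.inv_quad_logdet(inv_quad_rhs=rhs, logdet=True)
print(logdet)  # the log-determinant of an identity matrix is zero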