Commit 997156d
Merge pull request #10 from SebastianAment/add-matmul-edits
Edits to generic `add`, `BlockDiagLinearOperator`'s `matmul`, and documentation
Balandat authored Sep 8, 2022
2 parents (2e66c24 + c4fd5fa), commit 997156d
Showing 10 changed files with 35 additions and 19 deletions.

README.md (6 additions, 0 deletions)
@@ -4,3 +4,9 @@ A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch
 
 [![Run Test Suite](https://github.com/cornellius-gp/linear_operator/actions/workflows/run_test_suite.yml/badge.svg)](https://github.com/cornellius-gp/linear_operator/actions/workflows/run_test_suite.yml)
 [![Documentation Status](https://readthedocs.org/projects/linear-operator/badge/?version=latest)](https://linear-operator.readthedocs.io/en/latest/?badge=latest)
+
+## Development
+To run unit tests:
+```
+python -m unittest discover
+```

linear_operator/operators/_linear_operator.py (1 addition, 1 deletion)
@@ -2557,7 +2557,7 @@ def __add__(self, other: Union[torch.Tensor, LinearOperator, float]) -> LinearOp
         from .zero_linear_operator import ZeroLinearOperator
 
         if isinstance(other, ZeroLinearOperator):
-            return self
+            return deepcopy(self)
         elif isinstance(other, DiagLinearOperator):
             return AddedDiagLinearOperator(self, other)
         elif isinstance(other, RootLinearOperator):
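
As an illustrative sketch of what the copy buys (the imports below are assumed, not part of this diff): adding a `ZeroLinearOperator` should yield a new operator rather than an alias of the left-hand side.

```
import torch

from linear_operator import to_linear_operator
from linear_operator.operators import ZeroLinearOperator

A = to_linear_operator(torch.eye(3))
Z = ZeroLinearOperator(3, 3)

B = A + Z
# With `return deepcopy(self)`, the sum is a distinct object, so downstream
# mutation or caching on B no longer touches A.
assert B is not A
```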

linear_operator/operators/block_diag_linear_operator.py (11 additions, 2 deletions)
@@ -46,6 +46,15 @@ class BlockDiagLinearOperator(BlockLinearOperator, metaclass=_MetaBlockDiagLinea
         The dimension that specifies the blocks.
     """
 
+    def __init__(self, base_linear_op, block_dim=-3):
+        super().__init__(base_linear_op, block_dim)
+        # block diagonal is restricted to have square diagonal blocks
+        if self.base_linear_op.shape[-1] != self.base_linear_op.shape[-2]:
+            raise RuntimeError(
+                "base_linear_op must be a batch of square matrices, but non-batch dimensions are "
+                f"{base_linear_op.shape[-2:]}"
+            )
+
     @property
     def num_blocks(self):
         return self.base_linear_op.size(-3)
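
An illustrative sketch of the new validation (import paths assumed): construction succeeds for a batch of square blocks and raises for rectangular ones.

```
import torch

from linear_operator import to_linear_operator
from linear_operator.operators import BlockDiagLinearOperator

square_blocks = to_linear_operator(torch.randn(4, 3, 3))  # 4 blocks, each 3 x 3
op = BlockDiagLinearOperator(square_blocks)               # OK: a 12 x 12 operator

rect_blocks = to_linear_operator(torch.randn(4, 3, 2))    # rectangular blocks
try:
    BlockDiagLinearOperator(rect_blocks)
except RuntimeError as err:
    print(err)  # base_linear_op must be a batch of square matrices, ...
```
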
@@ -139,8 +148,8 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
     def matmul(self, other):
         from .diag_linear_operator import DiagLinearOperator
 
-        # this is trivial if we multiply two BlockDiagLinearOperator
-        if isinstance(other, BlockDiagLinearOperator):
+        # this is trivial if we multiply two BlockDiagLinearOperator with matching block sizes
+        if isinstance(other, BlockDiagLinearOperator) and self.base_linear_op.shape == other.base_linear_op.shape:
             return BlockDiagLinearOperator(self.base_linear_op @ other.base_linear_op)
         # special case if we have a DiagLinearOperator
         if isinstance(other, DiagLinearOperator):
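
A small sketch of the tightened fast path (imports assumed): when two `BlockDiagLinearOperator`s share the same block structure, their product stays block diagonal.

```
import torch

from linear_operator import to_linear_operator
from linear_operator.operators import BlockDiagLinearOperator

lhs = BlockDiagLinearOperator(to_linear_operator(torch.randn(4, 3, 3)))
rhs = BlockDiagLinearOperator(to_linear_operator(torch.randn(4, 3, 3)))

prod = lhs.matmul(rhs)
# Matching base shapes take the fast path above and stay block diagonal ...
assert isinstance(prod, BlockDiagLinearOperator)
# ... and agree with the dense product of the two 12 x 12 matrices.
assert torch.allclose(prod.to_dense(), lhs.to_dense() @ rhs.to_dense(), atol=1e-5)
```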

linear_operator/operators/block_linear_operator.py (2 additions, 3 deletions)
@@ -12,7 +12,7 @@
 class BlockLinearOperator(LinearOperator):
     """
     An abstract LinearOperator class for block tensors.
-    Super classes will determine how the different blocks are layed out
+    Subclasses will determine how the different blocks are layed out
     (e.g. block diagonal, sum over blocks, etc.)
     BlockLinearOperators represent the groups of blocks as a batched Tensor.
@@ -39,7 +39,7 @@ def __init__(self, base_linear_op, block_dim=-3):
         block_dim = block_dim if block_dim < 0 else (block_dim - base_linear_op.dim())
 
         # Everything is MUCH easier to write if the last batch dimension is the block dimension
-        # I.e. blopck_dim = -3
+        # I.e. block_dim = -3
         # We'll permute the dimensions if this is not the case
         if block_dim != -3:
             positive_block_dim = base_linear_op.dim() + block_dim
@@ -48,7 +48,6 @@
                 *range(positive_block_dim + 1, base_linear_op.dim() - 2),
                 positive_block_dim,
             )
-
         super(BlockLinearOperator, self).__init__(to_linear_operator(base_linear_op))
         self.base_linear_op = base_linear_op
 
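
A plain-torch illustration of the dimension shuffle performed in `__init__` above (tensor and variable names are made up for this sketch; the class itself goes through an internal batch-permute helper): an arbitrary `block_dim` is moved so that it lands third from the right.

```
import torch

base = torch.randn(5, 2, 7, 3, 3)  # block dimension at index 1, i.e. block_dim = -4
block_dim = -4
positive_block_dim = base.dim() + block_dim  # 1

permuted = base.permute(
    *range(positive_block_dim),                      # batch dims before the block dim
    *range(positive_block_dim + 1, base.dim() - 2),  # remaining batch dims
    positive_block_dim,                              # block dim moved to -3
    -2, -1,                                          # matrix dims stay last
)
print(permuted.shape)  # torch.Size([5, 7, 2, 3, 3])
```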

linear_operator/operators/cat_linear_operator.py (1 addition, 1 deletion)
@@ -33,7 +33,7 @@ def cat(inputs, dim=0, output_device=None):
 
 class CatLinearOperator(LinearOperator):
     r"""
-    A `LinearOperator` that represents the concatenation of other lazy tensors.
+    A `LinearOperator` that represents the concatenation of other linear operators.
     Each LinearOperator must have the same shape except in the concatenating
     dimension.
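
A hedged usage sketch of the module-level `cat` helper named in the hunk header (import path and batch-concatenation behavior assumed): concatenating two batches of operators along a batch dimension.

```
import torch

from linear_operator import to_linear_operator
from linear_operator.operators.cat_linear_operator import cat

A = to_linear_operator(torch.randn(2, 4, 4))  # a batch of two 4 x 4 operators
B = to_linear_operator(torch.randn(3, 4, 4))  # a batch of three 4 x 4 operators

C = cat([A, B], dim=0)  # shapes agree except in the concatenating dimension
print(C.shape)          # torch.Size([5, 4, 4])
```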

linear_operator/operators/constant_mul_linear_operator.py (5 additions, 5 deletions)
@@ -21,7 +21,7 @@ class ConstantMulLinearOperator(LinearOperator):
     .. note::
-        To element-wise multiply two lazy tensors, see :class:`linear_operator.lazy.MulLinearOperator`
+        To element-wise multiply two lazy tensors, see :class:`linear_operator.operators.MulLinearOperator`
     Args:
         base_linear_op (LinearOperator) or (b x n x m)): The base_lazy tensor
@@ -38,18 +38,18 @@ class ConstantMulLinearOperator(LinearOperator):
     Example::
-        >>> base_base_linear_op = linear_operator.lazy.ToeplitzLinearOperator([1, 2, 3])
+        >>> base_base_linear_op = linear_operator.operators.ToeplitzLinearOperator([1, 2, 3])
         >>> constant = torch.tensor(1.2)
-        >>> new_base_linear_op = linear_operator.lazy.ConstantMulLinearOperator(base_base_linear_op, constant)
+        >>> new_base_linear_op = linear_operator.operators.ConstantMulLinearOperator(base_base_linear_op, constant)
         >>> new_base_linear_op.to_dense()
         >>> # Returns:
         >>> # [[ 1.2, 2.4, 3.6 ]
         >>> # [ 2.4, 1.2, 2.4 ]
         >>> # [ 3.6, 2.4, 1.2 ]]
         >>>
-        >>> base_base_linear_op = linear_operator.lazy.ToeplitzLinearOperator([[1, 2, 3], [2, 3, 4]])
+        >>> base_base_linear_op = linear_operator.operators.ToeplitzLinearOperator([[1, 2, 3], [2, 3, 4]])
         >>> constant = torch.tensor([1.2, 0.5])
-        >>> new_base_linear_op = linear_operator.lazy.ConstantMulLinearOperator(base_base_linear_op, constant)
+        >>> new_base_linear_op = linear_operator.operators.ConstantMulLinearOperator(base_base_linear_op, constant)
         >>> new_base_linear_op.to_dense()
         >>> # Returns:
         >>> # [[[ 1.2, 2.4, 3.6 ]

linear_operator/test/linear_operator_test_case.py (3 additions, 3 deletions)
@@ -877,19 +877,19 @@ def test_diagonalization(self, symeig=False):
     def test_diagonalization_symeig(self):
         return self.test_diagonalization(symeig=True)
 
+    # NOTE: this is currently not executed, and fails if the underscore is removed
     def _test_triangular_linear_op_inv_quad_logdet(self):
         # now we need to test that a second cholesky isn't being called in the inv_quad_logdet
         with linear_operator.settings.max_cholesky_size(math.inf):
             linear_op = self.create_linear_op()
             rootdecomp = linear_operator.root_decomposition(linear_op)
 
-            if isinstance(rootdecomp, linear_operator.lazy.CholLinearOperator):
+            if isinstance(rootdecomp, linear_operator.operators.CholLinearOperator):
                 chol = linear_operator.root_decomposition(linear_op).root.clone()
                 linear_operator.utils.memoize.clear_cache_hook(linear_op)
                 linear_operator.utils.memoize.add_to_cache(
                     linear_op,
                     "root_decomposition",
-                    linear_operator.lazy.RootLinearOperator(chol),
+                    linear_operator.operators.RootLinearOperator(chol),
                 )
 
             _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)

linear_operator/utils/contour_integral_quad.py (1 addition, 1 deletion)
@@ -23,7 +23,7 @@ def contour_integral_quad(
     Performs :math:`\mathbf K^{1/2} \mathbf b` or :math:`\mathbf K^{-1/2} \mathbf b`
     using contour integral quadrature.
-    :param linear_operator.lazy.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K`
+    :param linear_operator.operators.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K`
     :param torch.Tensor rhs: Right hand side tensor :math:`\mathbf b`
     :param bool inverse: (default False) whether to compute :math:`\mathbf K^{1/2} \mathbf b` (if False)
         or `\mathbf K^{-1/2} \mathbf b` (if True)
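
For reference, the quantities this routine approximates, computed densely with plain torch on a tiny positive definite matrix (this is not the contour integral quadrature algorithm itself, only the target of the approximation):

```
import torch

K = torch.tensor([[4.0, 1.0], [1.0, 3.0]])  # a small symmetric positive definite K
b = torch.randn(2, 1)

evals, evecs = torch.linalg.eigh(K)
K_half = evecs @ torch.diag(evals.sqrt()) @ evecs.T       # K^{1/2}
K_inv_half = evecs @ torch.diag(evals.rsqrt()) @ evecs.T  # K^{-1/2}

print(K_half @ b)      # what contour_integral_quad targets with inverse=False
print(K_inv_half @ b)  # what it targets with inverse=True
```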

linear_operator/utils/permutation.py (1 addition, 1 deletion)
@@ -30,7 +30,7 @@ def apply_permutation(
     Broadcasting rules apply.
     :param matrix: :math:`\mathbf K`
-    :type matrix: ~linear_operator.lazy.LinearOperator or ~torch.Tensor (... x n x n)
+    :type matrix: ~linear_operator.operators.LinearOperator or ~torch.Tensor (... x n x n)
     :param left_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{left}`
     :type left_permutation: ~torch.Tensor, optional (... x <= n)
     :param right_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{right}`
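
An illustrative sketch of the row/column re-ordering this utility performs (the comparison against plain indexing is an assumption about its convention, not taken from this diff):

```
import torch

from linear_operator.utils.permutation import apply_permutation

K = torch.arange(9.0).reshape(3, 3)
left = torch.tensor([2, 0, 1])   # row permutation, Pi_left
right = torch.tensor([1, 2, 0])  # column permutation, Pi_right

out = apply_permutation(K, left, right)
# Expected to match plain row-then-column indexing of K (up to convention):
print(out)
print(K[left][:, right])
```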

(file name not shown: 4 additions, 2 deletions)
@@ -67,10 +67,12 @@ def _test_solve(self, rhs, lhs=None, cholesky=False):
 
         self.assertFalse(linear_cg_mock.called)
 
-    def _test_inv_quad_logdet(self, reduce_inv_quad=True, cholesky=False):
+    # NOTE: this is currently not executed
+    def _test_inv_quad_logdet(self, reduce_inv_quad=True, cholesky=False, linear_op=None):
         if not self.__class__.skip_slq_tests:
             # Forward
-            linear_op = self.create_linear_op()
+            if linear_op is None:
+                linear_op = self.create_linear_op()
             evaluated = self.evaluate_linear_op(linear_op)
             flattened_evaluated = evaluated.view(-1, *linear_op.matrix_shape)
 