From b5b65d78e4808ed38650e506809700b37f8affae Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Wed, 7 Sep 2022 11:42:11 -0400 Subject: [PATCH] deepcopy in add, size check in BlockDiag matmul, doc edits --- linear_operator/operators/_linear_operator.py | 2 +- .../operators/block_diag_linear_operator.py | 7 +++++-- linear_operator/operators/cat_linear_operator.py | 2 +- .../operators/constant_mul_linear_operator.py | 10 +++++----- linear_operator/test/linear_operator_test_case.py | 4 ++-- linear_operator/utils/contour_integral_quad.py | 2 +- linear_operator/utils/permutation.py | 2 +- 7 files changed, 16 insertions(+), 13 deletions(-) diff --git a/linear_operator/operators/_linear_operator.py b/linear_operator/operators/_linear_operator.py index 3c1c1a24..965b1a02 100644 --- a/linear_operator/operators/_linear_operator.py +++ b/linear_operator/operators/_linear_operator.py @@ -2557,7 +2557,7 @@ def __add__(self, other: Union[torch.Tensor, LinearOperator, float]) -> LinearOp from .zero_linear_operator import ZeroLinearOperator if isinstance(other, ZeroLinearOperator): - return self + return deepcopy(self) elif isinstance(other, DiagLinearOperator): return AddedDiagLinearOperator(self, other) elif isinstance(other, RootLinearOperator): diff --git a/linear_operator/operators/block_diag_linear_operator.py b/linear_operator/operators/block_diag_linear_operator.py index af21007d..f9df8681 100644 --- a/linear_operator/operators/block_diag_linear_operator.py +++ b/linear_operator/operators/block_diag_linear_operator.py @@ -139,8 +139,11 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True) def matmul(self, other): from .diag_linear_operator import DiagLinearOperator - # this is trivial if we multiply two BlockDiagLinearOperator - if isinstance(other, BlockDiagLinearOperator): + # this is trivial if we multiply two BlockDiagLinearOperator with matching block sizes + if ( + isinstance(other, BlockDiagLinearOperator) + and self.base_linear_op.shape[-1] == other.base_linear_op.shape[0] + ): return BlockDiagLinearOperator(self.base_linear_op @ other.base_linear_op) # special case if we have a DiagLinearOperator if isinstance(other, DiagLinearOperator): diff --git a/linear_operator/operators/cat_linear_operator.py b/linear_operator/operators/cat_linear_operator.py index dc75e8fb..180c183f 100644 --- a/linear_operator/operators/cat_linear_operator.py +++ b/linear_operator/operators/cat_linear_operator.py @@ -33,7 +33,7 @@ def cat(inputs, dim=0, output_device=None): class CatLinearOperator(LinearOperator): r""" - A `LinearOperator` that represents the concatenation of other lazy tensors. + A `LinearOperator` that represents the concatenation of other linear operators. Each LinearOperator must have the same shape except in the concatenating dimension. diff --git a/linear_operator/operators/constant_mul_linear_operator.py b/linear_operator/operators/constant_mul_linear_operator.py index 197cd330..ab686d1a 100644 --- a/linear_operator/operators/constant_mul_linear_operator.py +++ b/linear_operator/operators/constant_mul_linear_operator.py @@ -21,7 +21,7 @@ class ConstantMulLinearOperator(LinearOperator): .. note:: - To element-wise multiply two lazy tensors, see :class:`linear_operator.lazy.MulLinearOperator` + To element-wise multiply two lazy tensors, see :class:`linear_operator.operators.MulLinearOperator` Args: base_linear_op (LinearOperator) or (b x n x m)): The base_lazy tensor @@ -38,18 +38,18 @@ class ConstantMulLinearOperator(LinearOperator): Example:: - >>> base_base_linear_op = linear_operator.lazy.ToeplitzLinearOperator([1, 2, 3]) + >>> base_base_linear_op = linear_operator.operators.ToeplitzLinearOperator([1, 2, 3]) >>> constant = torch.tensor(1.2) - >>> new_base_linear_op = linear_operator.lazy.ConstantMulLinearOperator(base_base_linear_op, constant) + >>> new_base_linear_op = linear_operator.operators.ConstantMulLinearOperator(base_base_linear_op, constant) >>> new_base_linear_op.to_dense() >>> # Returns: >>> # [[ 1.2, 2.4, 3.6 ] >>> # [ 2.4, 1.2, 2.4 ] >>> # [ 3.6, 2.4, 1.2 ]] >>> - >>> base_base_linear_op = linear_operator.lazy.ToeplitzLinearOperator([[1, 2, 3], [2, 3, 4]]) + >>> base_base_linear_op = linear_operator.operators.ToeplitzLinearOperator([[1, 2, 3], [2, 3, 4]]) >>> constant = torch.tensor([1.2, 0.5]) - >>> new_base_linear_op = linear_operator.lazy.ConstantMulLinearOperator(base_base_linear_op, constant) + >>> new_base_linear_op = linear_operator.operators.ConstantMulLinearOperator(base_base_linear_op, constant) >>> new_base_linear_op.to_dense() >>> # Returns: >>> # [[[ 1.2, 2.4, 3.6 ] diff --git a/linear_operator/test/linear_operator_test_case.py b/linear_operator/test/linear_operator_test_case.py index 41fd94a4..15a89dec 100644 --- a/linear_operator/test/linear_operator_test_case.py +++ b/linear_operator/test/linear_operator_test_case.py @@ -883,13 +883,13 @@ def _test_triangular_linear_op_inv_quad_logdet(self): linear_op = self.create_linear_op() rootdecomp = linear_operator.root_decomposition(linear_op) - if isinstance(rootdecomp, linear_operator.lazy.CholLinearOperator): + if isinstance(rootdecomp, linear_operator.operators.CholLinearOperator): chol = linear_operator.root_decomposition(linear_op).root.clone() linear_operator.utils.memoize.clear_cache_hook(linear_op) linear_operator.utils.memoize.add_to_cache( linear_op, "root_decomposition", - linear_operator.lazy.RootLinearOperator(chol), + linear_operator.operators.RootLinearOperator(chol), ) _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex) diff --git a/linear_operator/utils/contour_integral_quad.py b/linear_operator/utils/contour_integral_quad.py index 93ad9c16..8b861958 100644 --- a/linear_operator/utils/contour_integral_quad.py +++ b/linear_operator/utils/contour_integral_quad.py @@ -23,7 +23,7 @@ def contour_integral_quad( Performs :math:`\mathbf K^{1/2} \mathbf b` or :math:`\mathbf K^{-1/2} \mathbf b` using contour integral quadrature. - :param linear_operator.lazy.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K` + :param linear_operator.operators.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K` :param torch.Tensor rhs: Right hand side tensor :math:`\mathbf b` :param bool inverse: (default False) whether to compute :math:`\mathbf K^{1/2} \mathbf b` (if False) or `\mathbf K^{-1/2} \mathbf b` (if True) diff --git a/linear_operator/utils/permutation.py b/linear_operator/utils/permutation.py index ba26e7b3..042602c0 100644 --- a/linear_operator/utils/permutation.py +++ b/linear_operator/utils/permutation.py @@ -30,7 +30,7 @@ def apply_permutation( Broadcasting rules apply. :param matrix: :math:`\mathbf K` - :type matrix: ~linear_operator.lazy.LinearOperator or ~torch.Tensor (... x n x n) + :type matrix: ~linear_operator.operators.LinearOperator or ~torch.Tensor (... x n x n) :param left_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{left}` :type left_permutation: ~torch.Tensor, optional (... x <= n) :param right_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{right}`