diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 40a5696e..068a519f 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -18,6 +18,8 @@ requirements: run: - pytorch>=1.11 - scipy + - jaxtyping>=0.2.9 + - typeguard~=2.13.3 test: imports: diff --git a/linear_operator/operators/__init__.py b/linear_operator/operators/__init__.py index 06fbc79f..0ed19b87 100644 --- a/linear_operator/operators/__init__.py +++ b/linear_operator/operators/__init__.py @@ -6,6 +6,7 @@ from .block_diag_linear_operator import BlockDiagLinearOperator from .block_interleaved_linear_operator import BlockInterleavedLinearOperator from .block_linear_operator import BlockLinearOperator +from .block_matrix_linear_operator import BlockMatrixLinearOperator from .cat_linear_operator import cat, CatLinearOperator from .chol_linear_operator import CholLinearOperator from .constant_mul_linear_operator import ConstantMulLinearOperator @@ -46,6 +47,7 @@ "BlockLinearOperator", "BlockDiagLinearOperator", "BlockInterleavedLinearOperator", + "BlockMatrixLinearOperator", "CatLinearOperator", "CholLinearOperator", "ConstantDiagLinearOperator", diff --git a/linear_operator/operators/block_matrix_linear_operator.py b/linear_operator/operators/block_matrix_linear_operator.py new file mode 100644 index 00000000..9aba135a --- /dev/null +++ b/linear_operator/operators/block_matrix_linear_operator.py @@ -0,0 +1,170 @@ +import math +from typing import List, Optional, Union + +import torch +from jaxtyping import Float +from torch import Tensor + +from .. import settings +from ._linear_operator import IndexType, LinearOperator, to_dense +from .dense_linear_operator import DenseLinearOperator +from .zero_linear_operator import ZeroLinearOperator + + +class BlockMatrixLinearOperator(LinearOperator): + """ + A TxT block matrix of LinearOperators. + + Idea. Represent [TN, TM] tensors by TxT blocks of NxM lazy tensors. + + Implementation. A block linear operator class that can keep track of the [T, T] block structure, + represented as T^2 lazy tensors of the same shape. Implement matrix multiplication between block matrices as + the appropriate linear operations on the blocks. + + :param linear_operators: A T^2 (flattened) list of linear operators representing a 2-D TxT block matrix. + The list of linear operators should be flattened into a concatenation of block-rows.
+ """ + + def __init__(self, *flattened_linear_operators: LinearOperator) -> None: + self.num_tasks = int(math.sqrt(len(flattened_linear_operators))) + + if settings.debug.on(): + assert len(flattened_linear_operators) > 0, "must have non-empty list" + assert self.num_tasks**2 == len(flattened_linear_operators) + + super().__init__(*flattened_linear_operators) + + self.linear_operators = tuple( + flattened_linear_operators[i * self.num_tasks : (i + 1) * self.num_tasks] for i in range(self.num_tasks) + ) + self.block_rows = self.linear_operators[0][0].shape[0] + self.block_cols = self.linear_operators[0][0].shape[1] + + @staticmethod + def create_square_ops_output(T: int) -> List[List[LinearOperator]]: + """Return an empty (square) list of operators of shape TxT""" + ops = [] + for i in range(T): + tmp = [] + for j in range(T): + tmp.append([]) + ops.append(tmp) + return ops + + def _matmul_two_block_matrix_linear_operators( + self: "BlockMatrixLinearOperator", + other: "BlockMatrixLinearOperator", + ) -> "BlockMatrixLinearOperator": + assert self.num_tasks == other.num_tasks + assert self.block_cols == other.block_rows + + T = self.num_tasks + output = [] + for i in range(T): + for j in range(T): + out_ij = self.linear_operators[i][0] @ other.linear_operators[0][j] + for k in range(1, T): + out_ij += self.linear_operators[i][k] @ other.linear_operators[k][j] + output.append(out_ij) + return self.__class__(*output) + + def _matmul( + self: Float[LinearOperator, "*batch M N"], + rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], + ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: + T = self.num_tasks + + if isinstance(rhs, Tensor) and rhs.ndim == 2: + # Check both matrix dims divisible by T, + # reshape to (T, T, ), call block multiplication + if rhs.size(0) % T == 0 and rhs.size(1) % T == 0: + # A is block [N * T, M * T] and B is a general tensor/operator of shape [O, P]. + # If O and P are both divisible by T, + # then interpret B as a [O//T * T, P//T * T] block matrix + O_T = rhs.size(0) // T + P_T = rhs.size(1) // T + rhs_blocks_raw = rhs.reshape(T, O_T, T, P_T) + rhs_blocks = rhs_blocks_raw.permute(0, 2, 1, 3) + rhs_op = BlockMatrixLinearOperator.from_tensor(rhs_blocks, T) + return self._matmul_two_block_matrix_linear_operators(rhs_op).to_dense() + + # Failover implementation. Convert to dense and multiply matricies + # Batch logic is not supported for now + assert rhs.dim() <= 2 + A = self.to_dense() + B = to_dense(rhs) + + res = A @ B + return res + + def matmul( + self: Float[LinearOperator, "*batch M N"], + other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], + ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: + # A is block [N * T1, M * T2] and B is block [O * S1, P * S2]. If A and B have conformal block counts + # ie T2==S1 as well as M==O then use the blockwise algorithm. Else use to_dense() + if isinstance(other, self.__class__): + if self.num_tasks == other.num_tasks and self.block_cols == other.block_rows: + return self._matmul_two_block_matrix_linear_operators(other) + elif isinstance(other, LinearOperator): + from .matmul_linear_operator import MatmulLinearOperator + + return MatmulLinearOperator(self, other) + + # The base method wants to perform a matmul via broadcasting and a + # representation tree which this operator doesn't support. 
+ return self._matmul(other) + + def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + out = [] + for i in range(self.num_tasks): + rows = [] + for j in range(self.num_tasks): + rows.append(self.linear_operators[i][j].to_dense()) + out.append(torch.concat(rows, axis=1)) + return torch.concat(out, axis=0) + + def _size(self) -> torch.Size: + sz = self.linear_operators[0][0].size() + return torch.Size([self.num_tasks * sz[0], self.num_tasks * sz[1]]) + + @property + def dtype(self) -> Optional[torch.dtype]: + return self.linear_operators[0][0].dtype + + @property + def device(self) -> Optional[torch.device]: + return self.linear_operators[0][0].device + + def _diag(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: + out = [] + for i in range(self.num_tasks): + # The underlying operators will test if they are square + diagonal = self.linear_operators[i][i].diagonal() + out.append(diagonal) + return torch.concat(out, axis=-1) + + def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + out = [] + for i in range(self.num_tasks): + for j in range(self.num_tasks): + out.append(self.linear_operators[j][i].mT) + return BlockMatrixLinearOperator(*out) + + def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator: + # Perform the __getitem__ + tsr = self.to_dense() + res = tsr[(*batch_indices, row_index, col_index)] + return DenseLinearOperator(res) + + @classmethod + def from_tensor(cls, tensor: Tensor, num_tasks: int) -> "BlockMatrixLinearOperator": + def tensor_to_linear_op(t: Tensor) -> LinearOperator: + if torch.count_nonzero(t) > 0: + return DenseLinearOperator(t) + return ZeroLinearOperator(*t.size(), dtype=t.dtype, device=t.device) + + linear_ops = [ + tensor_to_linear_op(t[0]) for i in range(num_tasks) for t in list(torch.tensor_split(tensor[i], num_tasks)) + ] + return cls(*linear_ops) diff --git a/linear_operator/test/linear_operator_core_test_case.py b/linear_operator/test/linear_operator_core_test_case.py new file mode 100644 index 00000000..005a9f6c --- /dev/null +++ b/linear_operator/test/linear_operator_core_test_case.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 + +from abc import abstractmethod + +import torch + +import linear_operator +from linear_operator.operators import DiagLinearOperator, to_dense +from .base_test_case import BaseTestCase + +""" +From the project description, a LinearOperator is a class that: + +- specifies the tensor(s) needed to define the LinearOperator, +- specifies a _matmul function (how the LinearOperator is applied to a vector), +- specifies a _size function (how big is the LinearOperator if it is represented as a matrix, or batch of matrices), and +- specifies a _transpose_nonbatch function (the adjoint of the LinearOperator). +- (optionally) defines other functions (e.g. logdet, eigh, etc.) to accelerate computations for which efficient + structure-exploiting routines exist. + +What follows is a class to test these core LinearOperator operations. +Note that batch operations are excluded here since they are not part of the core definition.
+""" + + +class CoreLinearOperatorTestCase(BaseTestCase): + """Test the core operations for a LinearOperator""" + + tolerances = { + "matmul": {"rtol": 1e-3}, + "transpose": {"rtol": 1e-4, "atol": 1e-5}, + } + + @abstractmethod + def create_linear_op(self): + raise NotImplementedError() + + @abstractmethod + def evaluate_linear_op(self): + raise NotImplementedError() + + def _test_matmul(self, rhs): + linear_op = self.create_linear_op().detach().requires_grad_(True) + linear_op_copy = torch.clone(linear_op).detach().requires_grad_(True) + evaluated = self.evaluate_linear_op(linear_op_copy) + rhs_evaluated = to_dense(rhs) + + # Test operator + res = linear_op @ rhs + actual = evaluated.matmul(rhs_evaluated) + res_evaluated = to_dense(res) + self.assertAllClose(res_evaluated, actual) + + # Test __torch_function__ + res = torch.matmul(linear_op, rhs) + actual = evaluated.matmul(rhs) + self.assertAllClose(to_dense(res), actual) + + def test_transpose_nonbatch(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + res = linear_op._transpose_nonbatch() + actual = evaluated.mT + res_evaluated = to_dense(res) + self.assertAllClose(res_evaluated, actual, **self.tolerances["transpose"]) + + def _test_rmatmul(self, lhs): + # Note. transpose_nonbatch is tested implicitly here because + # the base linear operator class defines + # def rmatmul(other): + # return self.mT.matmul(other.mT).mT + linear_op = self.create_linear_op().detach().requires_grad_(True) + linear_op_copy = torch.clone(linear_op).detach().requires_grad_(True) + evaluated = self.evaluate_linear_op(linear_op_copy) + + # Test operator + res = lhs @ linear_op + res_evaluated = to_dense(res) + actual = lhs @ evaluated + self.assertAllClose(res_evaluated, actual) + + # Test __torch_function__ + res = torch.matmul(lhs, linear_op) + res_evaluated = to_dense(res) + actual = torch.matmul(lhs, evaluated) + self.assertAllClose(res_evaluated, actual) + + def test_add(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + rhs = torch.randn(linear_op.shape) + # Test operator functionality + a = (linear_op + rhs).to_dense() + b = evaluated + rhs + self.assertAllClose(a, b) + self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) + self.assertAllClose((rhs + linear_op).to_dense(), evaluated + rhs) + # Test __torch_function__ functionality + self.assertAllClose(torch.add(linear_op, rhs).to_dense(), evaluated + rhs) + self.assertAllClose(torch.add(rhs, linear_op).to_dense(), evaluated + rhs) + + rhs = torch.randn(linear_op.matrix_shape) + self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) + + # rhs = torch.randn(2, *linear_op.shape) + # self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) + + self.assertAllClose((linear_op + linear_op).to_dense(), evaluated * 2) + return linear_op, evaluated + + def test_matmul_vec(self): + linear_op = self.create_linear_op() + + # We skip this test if we're dealing with batch LinearOperators + # They shouldn't multiply by a vec + if linear_op.ndimension() > 2: + return + + rhs = torch.randn(linear_op.size(-1)) + return self._test_matmul(rhs) + + def test_constant_mul(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + # Test operator functionality + self.assertAllClose((linear_op * 5.0).to_dense(), evaluated * 5.0) + self.assertAllClose((linear_op * torch.tensor(5.0)).to_dense(), evaluated * 5.0) + self.assertAllClose((5.0 * linear_op).to_dense(), 
evaluated * 5.0) + self.assertAllClose((torch.tensor(5.0) * linear_op).to_dense(), evaluated * 5.0) + + # Test __torch_function__ functionality + self.assertAllClose(torch.mul(linear_op, torch.tensor(5.0)).to_dense(), evaluated * 5.0) + self.assertAllClose(torch.mul(torch.tensor(5.0), linear_op).to_dense(), evaluated * 5.0) + + def test_constant_mul_neg(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + self.assertAllClose((linear_op * -5.0).to_dense(), evaluated * -5.0) + + def test_constant_div(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + # Test operator functionality + self.assertAllClose((linear_op / 5.0).to_dense(), evaluated / 5.0) + self.assertAllClose((linear_op / torch.tensor(5.0)).to_dense(), evaluated / 5.0) + + # Test __torch_function__ functionality + self.assertAllClose(torch.div(linear_op, torch.tensor(5.0)).to_dense(), evaluated / 5.0) + + def test_to_dense(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + self.assertAllClose(linear_op.to_dense(), evaluated) + + def test_getitem(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + # Non-batch case + if linear_op.ndimension() == 2: + res = linear_op[1] + actual = evaluated[1] + self.assertAllClose(res, actual) + res = linear_op[0:2].to_dense() + actual = evaluated[0:2] + self.assertAllClose(res, actual) + res = linear_op[:, 0:2].to_dense() + actual = evaluated[:, 0:2] + self.assertAllClose(res, actual) + res = linear_op[0:2, :].to_dense() + actual = evaluated[0:2, :] + self.assertAllClose(res, actual) + res = linear_op[..., 0:2].to_dense() + actual = evaluated[..., 0:2] + self.assertAllClose(res, actual) + res = linear_op[0:2, ...].to_dense() + actual = evaluated[0:2, ...] 
+ self.assertAllClose(res, actual) + res = linear_op[..., 0:2, 2] + actual = evaluated[..., 0:2, 2] + self.assertAllClose(res, actual) + res = linear_op[0:2, ..., 2] + actual = evaluated[0:2, ..., 2] + self.assertAllClose(res, actual) + + def test_getitem_tensor_index(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + # Non-batch case + if linear_op.ndimension() == 2: + index = (torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2])) + res, actual = linear_op[index], evaluated[index] + self.assertAllClose(res, actual) + index = (torch.tensor([0, 0, 1, 2]), slice(None, None, None)) + res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] + self.assertAllClose(res, actual) + index = (slice(None, None, None), torch.tensor([0, 0, 1, 2])) + res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] + self.assertAllClose(res, actual) + index = (torch.tensor([0, 0, 1, 2]), Ellipsis) + res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] + self.assertAllClose(res, actual) + index = (Ellipsis, torch.tensor([0, 0, 1, 2])) + res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] + self.assertAllClose(res, actual) + index = (Ellipsis, torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2])) + res, actual = linear_op[index], evaluated[index] + self.assertAllClose(res, actual) + + def test_getitem_broadcasted_tensor_index(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + # Non-batch case + if linear_op.ndimension() == 2: + index = ( + torch.tensor([0, 0, 1, 2]).unsqueeze(-1), + torch.tensor([0, 1, 0, 2]).unsqueeze(-2), + ) + res, actual = linear_op[index], evaluated[index] + self.assertAllClose(res, actual) + index = ( + Ellipsis, + torch.tensor([0, 0, 1, 2]).unsqueeze(-2), + torch.tensor([0, 1, 0, 2]).unsqueeze(-1), + ) + res, actual = linear_op[index], evaluated[index] + self.assertAllClose(res, actual) + + def test_permute(self): + linear_op = self.create_linear_op() + if linear_op.dim() >= 4: + evaluated = self.evaluate_linear_op(linear_op) + dims = torch.randperm(linear_op.dim() - 2).tolist() + + # Call using __torch_function__ + res = torch.permute(linear_op, (*dims, -2, -1)).to_dense() + actual = torch.permute(evaluated, (*dims, -2, -1)) + self.assertAllClose(res, actual) + + # Call using method + res = linear_op.permute(*dims, -2, -1).to_dense() + actual = torch.permute(evaluated, (*dims, -2, -1)) + self.assertAllClose(res, actual) + + def test_rmatmul_vec(self): + linear_op = self.create_linear_op() + + # We skip this test if we're dealing with batch LinearOperators + # They shouldn't multiply by a vec + if linear_op.ndimension() > 2: + return + + lhs = torch.randn(linear_op.size(-2)) + return self._test_rmatmul(lhs) + + def test_matmul_matrix(self): + linear_op = self.create_linear_op() + rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-1), 4) + return self._test_matmul(rhs) + + def test_t_matmul_matrix(self): + with torch.no_grad(): + linear_op = self.create_linear_op() + rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-2), 4) + linear_op_copy = torch.clone(linear_op) + evaluated = self.evaluate_linear_op(linear_op_copy) + rhs_evaluated = to_dense(rhs) + + # Test operator + res = linear_op._t_matmul(rhs) + actual = evaluated.mT.matmul(rhs_evaluated) + res_evaluated = to_dense(res) + self.assertAllClose(res_evaluated, actual) + + def test_rmatmul_matrix(self): + linear_op = self.create_linear_op() + lhs = 
torch.randn(*linear_op.batch_shape, 4, linear_op.size(-2)) + return self._test_rmatmul(lhs) + + def test_matmul_diag_matrix(self): + linear_op = self.create_linear_op() + diag = torch.rand(*linear_op.batch_shape, linear_op.size(-1)) + rhs = DiagLinearOperator(diag) + return self._test_matmul(rhs) + + def test_rsub(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + rhs = torch.randn(linear_op.shape) + # Test operator functionality + self.assertAllClose((rhs - linear_op).to_dense(), rhs - evaluated) + # Test __torch_function__ functionality + self.assertAllClose(torch.sub(rhs, linear_op).to_dense(), rhs - evaluated) + + def test_sub(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + rhs = torch.randn(linear_op.shape) + # Test operator functionality + self.assertAllClose((linear_op - rhs).to_dense(), evaluated - rhs) + # Test __torch_function__ functionality + self.assertAllClose(torch.sub(linear_op, rhs).to_dense(), evaluated - rhs) + + def test_sum(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + self.assertAllClose(torch.sum(linear_op, -1), torch.sum(evaluated, -1)) + self.assertAllClose(torch.sum(linear_op, -2), torch.sum(evaluated, -2)) + if linear_op.ndimension() > 2: + self.assertAllClose(torch.sum(linear_op, -3).to_dense(), torch.sum(evaluated, -3)) + if linear_op.ndimension() > 3: + self.assertAllClose(torch.sum(linear_op, -4).to_dense(), torch.sum(evaluated, -4)) diff --git a/linear_operator/test/linear_operator_test_case.py b/linear_operator/test/linear_operator_test_case.py index dc13cd85..fb7b163d 100644 --- a/linear_operator/test/linear_operator_test_case.py +++ b/linear_operator/test/linear_operator_test_case.py @@ -10,16 +10,16 @@ import torch import linear_operator -from linear_operator.operators import DenseLinearOperator, DiagLinearOperator, to_dense +from linear_operator.operators import DenseLinearOperator, to_dense from linear_operator.settings import linalg_dtypes from linear_operator.utils.errors import CachingError from linear_operator.utils.memoize import get_from_cache from ..utils.warnings import PerformanceWarning -from .base_test_case import BaseTestCase +from .linear_operator_core_test_case import CoreLinearOperatorTestCase -class RectangularLinearOperatorTestCase(BaseTestCase): +class RectangularLinearOperatorTestCase(CoreLinearOperatorTestCase): tolerances = { "matmul": {"rtol": 1e-3}, @@ -59,19 +59,25 @@ def _test_matmul(self, rhs): self.assertAllClose(to_dense(res), actual) def _test_rmatmul(self, lhs): + # Note. 
transpose_nonbatch is tested implicitly here because + # the base linear operator class defines + # def rmatmul(other): + # return self.mT.matmul(other.mT).mT linear_op = self.create_linear_op().detach().requires_grad_(True) linear_op_copy = torch.clone(linear_op).detach().requires_grad_(True) evaluated = self.evaluate_linear_op(linear_op_copy) # Test operator res = lhs @ linear_op + res_evaluated = to_dense(res) actual = lhs @ evaluated - self.assertAllClose(res, actual) + self.assertAllClose(res_evaluated, actual) # Test __torch_function__ res = torch.matmul(lhs, linear_op) + res_evaluated = to_dense(res) actual = torch.matmul(lhs, evaluated) - self.assertAllClose(res, actual) + self.assertAllClose(res_evaluated, actual) grad = torch.randn_like(res) res.backward(gradient=grad) @@ -81,107 +87,19 @@ def _test_rmatmul(self, lhs): self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["matmul"]) def test_add(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - rhs = torch.randn(linear_op.shape) - # Test operator functionality - a = (linear_op + rhs).to_dense() - b = evaluated + rhs - self.assertAllClose(a, b) - self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) - self.assertAllClose((rhs + linear_op).to_dense(), evaluated + rhs) - # Test __torch_function__ functionality - self.assertAllClose(torch.add(linear_op, rhs).to_dense(), evaluated + rhs) - self.assertAllClose(torch.add(rhs, linear_op).to_dense(), evaluated + rhs) - - rhs = torch.randn(linear_op.matrix_shape) - self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) + linear_op, evaluated = super().test_add() + # Test a batch of 2 rhs = torch.randn(2, *linear_op.shape) self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) - self.assertAllClose((linear_op + linear_op).to_dense(), evaluated * 2) - - def test_matmul_vec(self): - linear_op = self.create_linear_op() - - # We skip this test if we're dealing with batch LinearOperators - # They shouldn't multiply by a vec - if linear_op.ndimension() > 2: - return - - rhs = torch.randn(linear_op.size(-1)) - return self._test_matmul(rhs) - - def test_constant_mul(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - # Test operator functionality - self.assertAllClose((linear_op * 5.0).to_dense(), evaluated * 5.0) - self.assertAllClose((linear_op * torch.tensor(5.0)).to_dense(), evaluated * 5.0) - self.assertAllClose((5.0 * linear_op).to_dense(), evaluated * 5.0) - self.assertAllClose((torch.tensor(5.0) * linear_op).to_dense(), evaluated * 5.0) - - # Test __torch_function__ functionality - self.assertAllClose(torch.mul(linear_op, torch.tensor(5.0)).to_dense(), evaluated * 5.0) - self.assertAllClose(torch.mul(torch.tensor(5.0), linear_op).to_dense(), evaluated * 5.0) - - def test_constant_mul_neg(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - self.assertAllClose((linear_op * -5.0).to_dense(), evaluated * -5.0) - - def test_constant_div(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - # Test operator functionality - self.assertAllClose((linear_op / 5.0).to_dense(), evaluated / 5.0) - self.assertAllClose((linear_op / torch.tensor(5.0)).to_dense(), evaluated / 5.0) - - # Test __torch_function__ functionality - self.assertAllClose(torch.div(linear_op, torch.tensor(5.0)).to_dense(), evaluated / 5.0) - - def test_to_dense(self): - linear_op = self.create_linear_op() - 
evaluated = self.evaluate_linear_op(linear_op) - self.assertAllClose(linear_op.to_dense(), evaluated) - def test_getitem(self): + super().test_getitem() linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) - # Non-batch case - if linear_op.ndimension() == 2: - res = linear_op[1] - actual = evaluated[1] - self.assertAllClose(res, actual) - res = linear_op[0:2].to_dense() - actual = evaluated[0:2] - self.assertAllClose(res, actual) - res = linear_op[:, 0:2].to_dense() - actual = evaluated[:, 0:2] - self.assertAllClose(res, actual) - res = linear_op[0:2, :].to_dense() - actual = evaluated[0:2, :] - self.assertAllClose(res, actual) - res = linear_op[..., 0:2].to_dense() - actual = evaluated[..., 0:2] - self.assertAllClose(res, actual) - res = linear_op[0:2, ...].to_dense() - actual = evaluated[0:2, ...] - self.assertAllClose(res, actual) - res = linear_op[..., 0:2, 2] - actual = evaluated[..., 0:2, 2] - self.assertAllClose(res, actual) - res = linear_op[0:2, ..., 2] - actual = evaluated[0:2, ..., 2] - self.assertAllClose(res, actual) - # Batch case - else: + if linear_op.ndimension() != 2: res = linear_op[1].to_dense() actual = evaluated[1] self.assertAllClose(res, actual) @@ -212,32 +130,12 @@ def test_getitem(self): self.assertAllClose(res, actual) def test_getitem_tensor_index(self): + super().test_getitem_tensor_index() linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) - # Non-batch case - if linear_op.ndimension() == 2: - index = (torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2])) - res, actual = linear_op[index], evaluated[index] - self.assertAllClose(res, actual) - index = (torch.tensor([0, 0, 1, 2]), slice(None, None, None)) - res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] - self.assertAllClose(res, actual) - index = (slice(None, None, None), torch.tensor([0, 0, 1, 2])) - res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] - self.assertAllClose(res, actual) - index = (torch.tensor([0, 0, 1, 2]), Ellipsis) - res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] - self.assertAllClose(res, actual) - index = (Ellipsis, torch.tensor([0, 0, 1, 2])) - res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] - self.assertAllClose(res, actual) - index = (Ellipsis, torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2])) - res, actual = linear_op[index], evaluated[index] - self.assertAllClose(res, actual) - # Batch case - else: + if linear_op.ndimension() != 2: for batch_index in product( [torch.tensor([0, 1, 1, 0]), slice(None, None, None)], repeat=(linear_op.dim() - 2), @@ -284,27 +182,12 @@ def test_getitem_tensor_index(self): self.assertAllClose(res, actual) def test_getitem_broadcasted_tensor_index(self): + super().test_getitem_broadcasted_tensor_index() linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) - # Non-batch case - if linear_op.ndimension() == 2: - index = ( - torch.tensor([0, 0, 1, 2]).unsqueeze(-1), - torch.tensor([0, 1, 0, 2]).unsqueeze(-2), - ) - res, actual = linear_op[index], evaluated[index] - self.assertAllClose(res, actual) - index = ( - Ellipsis, - torch.tensor([0, 0, 1, 2]).unsqueeze(-2), - torch.tensor([0, 1, 0, 2]).unsqueeze(-1), - ) - res, actual = linear_op[index], evaluated[index] - self.assertAllClose(res, actual) - # Batch case - else: + if linear_op.ndimension() != 2: for batch_index in product( [torch.tensor([0, 1, 1]).view(-1, 1, 1), slice(None, None, None)], 
repeat=(linear_op.dim() - 2), @@ -359,63 +242,6 @@ def test_getitem_broadcasted_tensor_index(self): ) self.assertAllClose(res, actual) - def test_permute(self): - linear_op = self.create_linear_op() - if linear_op.dim() >= 4: - evaluated = self.evaluate_linear_op(linear_op) - dims = torch.randperm(linear_op.dim() - 2).tolist() - - # Call using __torch_function__ - res = torch.permute(linear_op, (*dims, -2, -1)).to_dense() - actual = torch.permute(evaluated, (*dims, -2, -1)) - self.assertAllClose(res, actual) - - # Call using method - res = linear_op.permute(*dims, -2, -1).to_dense() - actual = torch.permute(evaluated, (*dims, -2, -1)) - self.assertAllClose(res, actual) - - def test_rmatmul_vec(self): - linear_op = self.create_linear_op() - - # We skip this test if we're dealing with batch LinearOperators - # They shouldn't multiply by a vec - if linear_op.ndimension() > 2: - return - - lhs = torch.randn(linear_op.size(-2)) - return self._test_rmatmul(lhs) - - def test_matmul_matrix(self): - linear_op = self.create_linear_op() - rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-1), 4) - return self._test_matmul(rhs) - - def test_t_matmul_matrix(self): - with torch.no_grad(): - linear_op = self.create_linear_op() - rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-2), 4) - linear_op_copy = torch.clone(linear_op) - evaluated = self.evaluate_linear_op(linear_op_copy) - rhs_evaluated = to_dense(rhs) - - # Test operator - res = linear_op._t_matmul(rhs) - actual = evaluated.mT.matmul(rhs_evaluated) - res_evaluated = to_dense(res) - self.assertAllClose(res_evaluated, actual) - - def test_rmatmul_matrix(self): - linear_op = self.create_linear_op() - lhs = torch.randn(*linear_op.batch_shape, 4, linear_op.size(-2)) - return self._test_rmatmul(lhs) - - def test_matmul_diag_matrix(self): - linear_op = self.create_linear_op() - diag = torch.rand(*linear_op.batch_shape, linear_op.size(-1)) - rhs = DiagLinearOperator(diag) - return self._test_matmul(rhs) - def test_matmul_matrix_broadcast(self): linear_op = self.create_linear_op() @@ -454,37 +280,6 @@ def test_rmatmul_matrix_broadcast(self): lhs = torch.randn(*batch_shape, 4, linear_op.size(-2)) self._test_rmatmul(lhs) - def test_rsub(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - rhs = torch.randn(linear_op.shape) - # Test operator functionality - self.assertAllClose((rhs - linear_op).to_dense(), rhs - evaluated) - # Test __torch_function__ functionality - self.assertAllClose(torch.sub(rhs, linear_op).to_dense(), rhs - evaluated) - - def test_sub(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - rhs = torch.randn(linear_op.shape) - # Test operator functionality - self.assertAllClose((linear_op - rhs).to_dense(), evaluated - rhs) - # Test __torch_function__ functionality - self.assertAllClose(torch.sub(linear_op, rhs).to_dense(), evaluated - rhs) - - def test_sum(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - self.assertAllClose(torch.sum(linear_op, -1), torch.sum(evaluated, -1)) - self.assertAllClose(torch.sum(linear_op, -2), torch.sum(evaluated, -2)) - if linear_op.ndimension() > 2: - self.assertAllClose(torch.sum(linear_op, -3).to_dense(), torch.sum(evaluated, -3)) - if linear_op.ndimension() > 3: - self.assertAllClose(torch.sum(linear_op, -4).to_dense(), torch.sum(evaluated, -4)) - def test_squeeze_unsqueeze(self): linear_operator = self.create_linear_op() evaluated = 
self.evaluate_linear_op(linear_operator) diff --git a/test/operators/test_block_matrix_linear_operator.py b/test/operators/test_block_matrix_linear_operator.py new file mode 100644 index 00000000..9a7da202 --- /dev/null +++ b/test/operators/test_block_matrix_linear_operator.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +import itertools +import unittest + +import torch + +from linear_operator.operators import BlockMatrixLinearOperator +from linear_operator.test.base_test_case import BaseTestCase +from linear_operator.test.linear_operator_core_test_case import CoreLinearOperatorTestCase + + +class TestBlockTensorSimple(BaseTestCase, unittest.TestCase): + def dense_to_4d(self, A_dense, T): + Ne = A_dense.size(0) // T + Me = A_dense.size(1) // T + A_blocks_est = A_dense.reshape(T, Ne, T, Me) + A_blocks_est = A_blocks_est.permute(0, 2, 1, 3) + return A_blocks_est + + def test_multiply(self): + T = 2 + N = 4 + M = 3 + K = 5 + + A = torch.randn(T, T, N, M) + B = torch.randn(T, T, M, K) + + A_blo = BlockMatrixLinearOperator.from_tensor(A, T) + B_blo = BlockMatrixLinearOperator.from_tensor(B, T) + res_AB = A_blo.matmul(B_blo) + res_dense_AB = res_AB.to_dense() + + A_dense = A.permute(0, 2, 1, 3).reshape(T * N, T * M) + B_dense = B.permute(0, 2, 1, 3).reshape(T * M, T * K) + expected = A_dense @ B_dense + self.assertAllClose(res_dense_AB, expected) + self.assertAllClose(A_dense, A_blo.to_dense()) + self.assertAllClose(B_dense, B_blo.to_dense()) + + # Convert dense format back to blocks and compare + A_blocks_est = self.dense_to_4d(A_dense, T) + self.assertAllClose(A, A_blocks_est) + + # Check Tensor multiplication + res_tensor_AB = A_blo.matmul(B_dense) + res_tensor_dense_AB = res_tensor_AB.to_dense() + self.assertAllClose(res_dense_AB, res_tensor_dense_AB) + + def test_sparse_multiply(self): + T, N, M = 2, 4, 3 + As = [torch.rand(N, M) for _ in range(T)] + Bs = [[torch.rand(M, M) for _ in range(T)] for _ in range(T)] + Cs = [torch.rand(N, N) for _ in range(T)] + # L = torch.rand(T, T) + + A_dense = torch.zeros((N * T, M * T)) # BlockDiag (non-square) + B_dense = torch.zeros((M * T, M * T)) # Dense + C_dense = torch.zeros((N * T, N * T)) # BlockDiag + # L_dense = torch.kron(L, torch.eye(N)) # Kronecker + + for t in range(T): + A_dense[N * t : N * (t + 1), M * t : M * (t + 1)] = As[t] + C_dense[N * t : N * (t + 1), N * t : N * (t + 1)] = Cs[t] + + for t1, t2 in itertools.product(range(T), range(T)): + B_dense[M * t1 : M * (t1 + 1), M * t2 : M * (t2 + 1)] = Bs[t1][t2] + + # Convert dense formats to blocks + A = self.dense_to_4d(A_dense, T) + B = self.dense_to_4d(B_dense, T) + + # A_blo will contain dense operators along the diagonal + Zero operators off diagonal + A_blo = BlockMatrixLinearOperator.from_tensor(A, T) + B_blo = BlockMatrixLinearOperator.from_tensor(B, T) + res_AB = A_blo.matmul(B_blo) + res_dense_AB = res_AB.to_dense() + + expected = A_dense @ B_dense + self.assertAllClose(res_dense_AB, expected) + self.assertAllClose(A_dense, A_blo.to_dense()) + self.assertAllClose(B_dense, B_blo.to_dense()) + + +class TestLinearOperatorBlockTensorLinearOperator(CoreLinearOperatorTestCase, unittest.TestCase): + seed = 0 + should_test_sample = False + T = 2 + N = M = 4 # Try a square for this set of tests + + A_dense = torch.eye(T * N) + A_blocks = A_dense.reshape(T, N, T, M).permute(0, 2, 1, 3) + + def create_linear_op(self): + A_blo = BlockMatrixLinearOperator.from_tensor(self.A_blocks, self.T) + return A_blo + + def evaluate_linear_op(self, linear_op): + D = linear_op.to_dense() + return D + + +if
__name__ == "__main__": + unittest.main()