Skip to content

Commit

Permalink
Merge pull request #85 from cornellius-gp/absolute_imports
Browse files Browse the repository at this point in the history
Use absolute imports
  • Loading branch information
Balandat authored Nov 20, 2023
2 parents 5496242 + 7858888 commit 6eff871
Show file tree
Hide file tree
Showing 61 changed files with 397 additions and 355 deletions.
6 changes: 3 additions & 3 deletions linear_operator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
from . import beta_features, operators, settings, utils
from .functions import (
from linear_operator import beta_features, operators, settings, utils
from linear_operator.functions import (
add_diagonal,
add_jitter,
diagonalization,
Expand All @@ -13,7 +13,7 @@
solve,
sqrt_inv_matmul,
)
from .operators import LinearOperator, to_dense, to_linear_operator
from linear_operator.operators import LinearOperator, to_dense, to_linear_operator

# Read version number as written by setuptools_scm
try:
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/beta_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings

from .settings import _feature_flag
from linear_operator.settings import _feature_flag


class _moved_beta_feature(object):
Expand Down
20 changes: 10 additions & 10 deletions linear_operator/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch

from ._dsmm import DSMM
from linear_operator.functions._dsmm import DSMM

LinearOperatorType = Any # Want this to be "LinearOperator" but runtime type checker can't handle

Expand All @@ -23,7 +23,7 @@ def add_diagonal(input: Anysor, diag: torch.Tensor) -> LinearOperatorType:
:return: :math:`\mathbf A + \text{diag}(\mathbf d)`, where :math:`\mathbf A` is the linear operator
and :math:`\mathbf d` is the diagonal component
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).add_diagonal(diag)

Expand Down Expand Up @@ -61,7 +61,7 @@ def diagonalization(
based on size if not specified.
:return: eigenvalues and eigenvectors representing the diagonalization.
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).diagonalization(method=method)

Expand Down Expand Up @@ -105,7 +105,7 @@ def inv_quad(input: Anysor, inv_quad_rhs: torch.Tensor, reduce_inv_quad: bool =
:returns: The inverse quadratic term.
If `reduce_inv_quad=True`, the inverse quadratic term is of shape (...). Otherwise, it is (... x M).
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).inv_quad(inv_quad_rhs, reduce_inv_quad=reduce_inv_quad)

Expand All @@ -127,7 +127,7 @@ def inv_quad_logdet(
:returns: The inverse quadratic term (or None), and the logdet term (or None).
If `reduce_inv_quad=True`, the inverse quadratic term is of shape (...). Otherwise, it is (... x M).
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).inv_quad_logdet(inv_quad_rhs, logdet, reduce_inv_quad=reduce_inv_quad)

Expand Down Expand Up @@ -156,7 +156,7 @@ def pivoted_cholesky(
.. _Harbrecht et al., 2012:
https://www.sciencedirect.com/science/article/pii/S0168927411001814
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).pivoted_cholesky(rank=rank, error_tol=error_tol, return_pivots=return_pivots)

Expand All @@ -173,7 +173,7 @@ def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOpe
"cholesky", "lanczos", "symeig", "pivoted_cholesky", or "svd".
:return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A`.
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).root_decomposition(method=method)

Expand All @@ -199,7 +199,7 @@ def root_inv_decomposition(
:param method: Root decomposition method to use (symeig, diagonalization, lanczos, or cholesky).
:return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A^{-1}`.
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).root_inv_decomposition(
initial_vectors=initial_vectors, test_vectors=test_vectors, method=method
Expand Down Expand Up @@ -235,7 +235,7 @@ def solve(input: Anysor, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None)
:param lhs: :math:`\mathbf L` - the left hand side
:return: :math:`\mathbf A^{-1} \mathbf R` or :math:`\mathbf L \mathbf A^{-1} \mathbf R`.
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).solve(right_tensor=rhs, left_tensor=lhs)

Expand Down Expand Up @@ -268,7 +268,7 @@ def sqrt_inv_matmul(
:param lhs: :math:`\mathbf L` - the left hand side
:return: :math:`\mathbf A^{-1/2} \mathbf R` or :math:`\mathbf L \mathbf A^{-1/2} \mathbf R`.
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

return to_linear_operator(input).sqrt_inv_matmul(rhs=rhs, lhs=lhs)

Expand Down
4 changes: 2 additions & 2 deletions linear_operator/functions/_diagonalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch
from torch.autograd import Function

from .. import settings
from ..utils import lanczos
from linear_operator import settings
from linear_operator.utils import lanczos


class Diagonalization(Function):
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/functions/_dsmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch.autograd import Function

from ..utils.sparse import bdsmm
from linear_operator.utils.sparse import bdsmm


class DSMM(Function):
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/functions/_inv_quad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.autograd import Function

from .. import settings
from linear_operator import settings


def _solve(linear_op, rhs):
Expand Down
6 changes: 3 additions & 3 deletions linear_operator/functions/_inv_quad_logdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import torch
from torch.autograd import Function

from .. import settings
from ..utils.lanczos import lanczos_tridiag_to_diag
from ..utils.stochastic_lq import StochasticLQ
from linear_operator import settings
from linear_operator.utils.lanczos import lanczos_tridiag_to_diag
from linear_operator.utils.stochastic_lq import StochasticLQ


class InvQuadLogdet(Function):
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/functions/_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch.autograd import Function

from .. import settings
from linear_operator import settings


class Matmul(Function):
Expand Down
6 changes: 3 additions & 3 deletions linear_operator/functions/_pivoted_cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import torch
from torch.autograd import Function

from .. import settings
from ..utils.cholesky import psd_safe_cholesky
from ..utils.permutation import apply_permutation, inverse_permutation
from linear_operator import settings
from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.permutation import apply_permutation, inverse_permutation


class PivotedCholesky(Function):
Expand Down
6 changes: 3 additions & 3 deletions linear_operator/functions/_root_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch
from torch.autograd import Function

from .. import settings
from ..utils import lanczos
from linear_operator import settings
from linear_operator.utils import lanczos


class RootDecomposition(Function):
Expand All @@ -29,7 +29,7 @@ def forward(
:return: :attr:`R`, such that :math:`R R^T \approx A`, and :attr:`R_inv`, such that
:math:`R_{inv} R_{inv}^T \approx A^{-1}` (will only be populated if self.inverse = True)
"""
from ..operators import to_linear_operator
from linear_operator.operators import to_linear_operator

ctx.representation_tree = representation_tree
ctx.device = device
Expand Down
4 changes: 2 additions & 2 deletions linear_operator/functions/_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import torch
from torch.autograd import Function

from .. import settings
from linear_operator import settings


def _solve(linear_op, rhs):
from ..operators import CholLinearOperator, TriangularLinearOperator
from linear_operator.operators import CholLinearOperator, TriangularLinearOperator

if isinstance(linear_op, (CholLinearOperator, TriangularLinearOperator)):
# May want to do this for some KroneckerProductLinearOperators and possibly
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/functions/_sqrt_inv_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.autograd import Function

from .. import settings, utils
from linear_operator import settings, utils


class SqrtInvMatmul(Function):
Expand Down
67 changes: 36 additions & 31 deletions linear_operator/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,45 @@
#!/usr/bin/env python3

from ._linear_operator import LinearOperator, to_dense
from .added_diag_linear_operator import AddedDiagLinearOperator
from .batch_repeat_linear_operator import BatchRepeatLinearOperator
from .block_diag_linear_operator import BlockDiagLinearOperator
from .block_interleaved_linear_operator import BlockInterleavedLinearOperator
from .block_linear_operator import BlockLinearOperator
from .cat_linear_operator import cat, CatLinearOperator
from .chol_linear_operator import CholLinearOperator
from .constant_mul_linear_operator import ConstantMulLinearOperator
from .dense_linear_operator import DenseLinearOperator, to_linear_operator
from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from .identity_linear_operator import IdentityLinearOperator
from .interpolated_linear_operator import InterpolatedLinearOperator
from .keops_linear_operator import KeOpsLinearOperator
from .kernel_linear_operator import KernelLinearOperator
from .kronecker_product_added_diag_linear_operator import KroneckerProductAddedDiagLinearOperator
from .kronecker_product_linear_operator import (
from linear_operator.operators._linear_operator import LinearOperator, to_dense
from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator
from linear_operator.operators.batch_repeat_linear_operator import BatchRepeatLinearOperator
from linear_operator.operators.block_diag_linear_operator import BlockDiagLinearOperator
from linear_operator.operators.block_interleaved_linear_operator import BlockInterleavedLinearOperator
from linear_operator.operators.block_linear_operator import BlockLinearOperator
from linear_operator.operators.cat_linear_operator import cat, CatLinearOperator
from linear_operator.operators.chol_linear_operator import CholLinearOperator
from linear_operator.operators.constant_mul_linear_operator import ConstantMulLinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator, to_linear_operator
from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from linear_operator.operators.identity_linear_operator import IdentityLinearOperator
from linear_operator.operators.interpolated_linear_operator import InterpolatedLinearOperator
from linear_operator.operators.keops_linear_operator import KeOpsLinearOperator
from linear_operator.operators.kernel_linear_operator import KernelLinearOperator
from linear_operator.operators.kronecker_product_added_diag_linear_operator import (
KroneckerProductAddedDiagLinearOperator,
)
from linear_operator.operators.kronecker_product_linear_operator import (
KroneckerProductDiagLinearOperator,
KroneckerProductLinearOperator,
KroneckerProductTriangularLinearOperator,
)
from .low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator
from .low_rank_root_linear_operator import LowRankRootLinearOperator
from .masked_linear_operator import MaskedLinearOperator
from .matmul_linear_operator import MatmulLinearOperator
from .mul_linear_operator import MulLinearOperator
from .permutation_linear_operator import PermutationLinearOperator, TransposePermutationLinearOperator
from .psd_sum_linear_operator import PsdSumLinearOperator
from .root_linear_operator import RootLinearOperator
from .sum_batch_linear_operator import SumBatchLinearOperator
from .sum_kronecker_linear_operator import SumKroneckerLinearOperator
from .sum_linear_operator import SumLinearOperator
from .toeplitz_linear_operator import ToeplitzLinearOperator
from .triangular_linear_operator import TriangularLinearOperator
from .zero_linear_operator import ZeroLinearOperator
from linear_operator.operators.low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator
from linear_operator.operators.low_rank_root_linear_operator import LowRankRootLinearOperator
from linear_operator.operators.masked_linear_operator import MaskedLinearOperator
from linear_operator.operators.matmul_linear_operator import MatmulLinearOperator
from linear_operator.operators.mul_linear_operator import MulLinearOperator
from linear_operator.operators.permutation_linear_operator import (
PermutationLinearOperator,
TransposePermutationLinearOperator,
)
from linear_operator.operators.psd_sum_linear_operator import PsdSumLinearOperator
from linear_operator.operators.root_linear_operator import RootLinearOperator
from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator
from linear_operator.operators.sum_kronecker_linear_operator import SumKroneckerLinearOperator
from linear_operator.operators.sum_linear_operator import SumLinearOperator
from linear_operator.operators.toeplitz_linear_operator import ToeplitzLinearOperator
from linear_operator.operators.triangular_linear_operator import TriangularLinearOperator
from linear_operator.operators.zero_linear_operator import ZeroLinearOperator

__all__ = [
"to_dense",
Expand Down
Loading

0 comments on commit 6eff871

Please sign in to comment.