diff --git a/examples/eigenvalue.py b/examples/eigenvalue.py index 4e4826a8..e80b9eaa 100644 --- a/examples/eigenvalue.py +++ b/examples/eigenvalue.py @@ -4,6 +4,14 @@ problem to the Sphere """ import torch +try: + from torch.linalg import eigvalsh +except ImportError: + from torch import symeig + + def eigvalsh(X): + return symeig(X, eigenvectors=False).eigenvalues + from torch import nn import geotorch @@ -29,8 +37,8 @@ def forward(self, A): A = torch.rand(N, N) # Uniform on [0, 1) A = 0.5 * (A + A.T) -# Compare against diagonalization -max_eigenvalue = torch.symeig(A).eigenvalues.max() +# Compare against diagonalization (eigenvalues are returend in ascending order) +max_eigenvalue = eigvalsh(A)[-1] print("Max eigenvalue: {:10.5f}".format(max_eigenvalue)) # Instantiate model and optimiser diff --git a/geotorch/constraints.py b/geotorch/constraints.py index 3384e295..2f62ae02 100644 --- a/geotorch/constraints.py +++ b/geotorch/constraints.py @@ -167,7 +167,7 @@ def almost_orthogonal(module, tensor_name="weight", lam=0.1, f="sin", triv="expm >>> layer = nn.Linear(20, 30) >>> geotorch.almost_orthogonal(layer, "weight", 0.5) - >>> S = torch.svd(layer.weight).S + >>> S = torch.linalg.svd(layer.weight).S >>> all(S >= 0.5 and S <= 1.5) True @@ -252,7 +252,7 @@ def low_rank(module, tensor_name, rank, triv="expm"): >>> layer = nn.Linear(20, 30) >>> geotorch.low_rank(layer, "weight", 4) - >>> list(torch.svd(layer.weight).S > 1e-7).count(True) <= 4 + >>> list(torch.linalg.svd(layer.weight).S > 1e-7).count(True) <= 4 True Args: @@ -284,7 +284,7 @@ def fixed_rank(module, tensor_name, rank, f="softplus", triv="expm"): >>> layer = nn.Linear(20, 30) >>> geotorch.fixed_rank(layer, "weight", 5) - >>> list(torch.svd(layer.weight).S > 1e-7).count(True) + >>> list(torch.linalg.svd(layer.weight).S > 1e-7).count(True) 5 Args: @@ -367,7 +367,7 @@ def positive_definite(module, tensor_name="weight", f="softplus", triv="expm"): >>> layer = nn.Linear(20, 20) >>> geotorch.positive_definite(layer, "weight") - >>> (torch.symeig(layer.weight).eigenvalues > 0.0).all() + >>> (torch.linalg.eigvalsh(layer.weight) > 0.0).all() tensor(True) Args: @@ -407,7 +407,7 @@ def positive_semidefinite(module, tensor_name="weight", triv="expm"): >>> layer = nn.Linear(20, 20) >>> geotorch.positive_semidefinite(layer, "weight") - >>> L = torch.symeig(layer.weight).eigenvalues + >>> L = torch.linalg.eigvalsh(layer.weight) >>> L[L.abs() < 1e-7] = 0.0 # Round errors >>> (L >= 0.0).all() tensor(True) @@ -439,7 +439,7 @@ def positive_semidefinite_low_rank(module, tensor_name, rank, triv="expm"): >>> layer = nn.Linear(20, 20) >>> geotorch.positive_semidefinite_low_rank(layer, "weight", 5) - >>> L = torch.symeig(layer.weight).eigenvalues + >>> L = torch.linalg.eigvalsh(layer.weight) >>> L[L.abs() < 1e-7] = 0.0 # Round errors >>> (L >= 0.0).all() tensor(True) @@ -478,7 +478,7 @@ def positive_semidefinite_fixed_rank( >>> layer = nn.Linear(20, 20) >>> geotorch.positive_semidefinite_fixed_rank(layer, "weight", 5) - >>> L = torch.symeig(layer.weight).eigenvalues + >>> L = torch.linalg.eigvalsh(layer.weight) >>> L[L.abs() < 1e-7] = 0.0 # Round errors >>> (L >= 0.0).all() tensor(True) diff --git a/geotorch/fixedrank.py b/geotorch/fixedrank.py index 0ee90722..af81f82f 100644 --- a/geotorch/fixedrank.py +++ b/geotorch/fixedrank.py @@ -120,7 +120,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6): """ U, S, V = super().sample(factorized=True, init_=init_) with torch.no_grad(): - # S >= 0, as given by torch.symeig() + # S >= 0, as given by torch.linalg.eigvalsh() S[S < eps] = eps if factorized: return U, S, V diff --git a/geotorch/lowrank.py b/geotorch/lowrank.py index 7160bff8..b2e92f85 100644 --- a/geotorch/lowrank.py +++ b/geotorch/lowrank.py @@ -1,4 +1,12 @@ import torch +from functools import partial +try: + from torch.linalg import svd + svd = partial(svd, full_matrices=False) +except ImportError: + from torch import svd + + from .product import ProductManifold from .stiefel import Stiefel from .reals import Rn @@ -80,7 +88,7 @@ def frame_inv(self, X1, X2, X3): def submersion_inv(self, X, check_in_manifold=True): if isinstance(X, torch.Tensor): - U, S, V = X.svd() + U, S, V = svd(X) if check_in_manifold and not self.in_manifold_singular_values(S): raise InManifoldError(X, self) else: @@ -136,7 +144,7 @@ def in_manifold(self, X, eps=1e-5): Args: X (torch.Tensor or tuple): The matrix to be checked or a tuple containing - :math:`(U, \Sigma, V)` as returned by ``torch.svd`` or + :math:`(U, \Sigma, V)` as returned by ``torch.linalg.svd`` or ``self.sample(factorized=True)``. eps (float): Optional. Threshold at which the singular values are considered to be zero @@ -152,7 +160,10 @@ def in_manifold(self, X, eps=1e-5): X = X.transpose(-2, -1) if X.size() != self.tensorial_size + (self.n, self.k): return False - _, S, _ = X.svd(compute_uv=False) + try: + S = torch.linalg.svdvals(X) + except AttributeError: + S = svd(X).S return self.in_manifold_singular_values(S, eps) def project(self, X, factorized=True): @@ -170,7 +181,7 @@ def project(self, X, factorized=True): used to initialize a parametrized tensor. Default: ``True`` """ - U, S, V = X.svd() + U, S, V = svd(X) U, S, V = U[..., : self.rank], S[..., : self.rank], V[..., : self.rank] if factorized: return U, S, V @@ -211,7 +222,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True): *(self.tensorial_size + (self.n, self.k)), device=device, dtype=dtype ) init_(X) - U, S, V = X.svd() + U, S, V = svd(X) U, S, V = U[..., : self.rank], S[..., : self.rank], V[..., : self.rank] if factorized: return U, S, V diff --git a/geotorch/pssdfixedrank.py b/geotorch/pssdfixedrank.py index 02b224d9..ef6ed7a6 100644 --- a/geotorch/pssdfixedrank.py +++ b/geotorch/pssdfixedrank.py @@ -101,7 +101,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6): """ L, Q = super().sample(factorized=True, init_=init_) with torch.no_grad(): - # S >= 0, as given by torch.symeig() + # S >= 0, as given by torch.linalg.eigvalsh() small = L < eps L[small] = eps if factorized: diff --git a/geotorch/so.py b/geotorch/so.py index c6890e6f..8b6e471a 100644 --- a/geotorch/so.py +++ b/geotorch/so.py @@ -1,6 +1,10 @@ import math import torch from torch import nn +try: + from torch.linalg import qr +except ImportError: + from torch import qr from .utils import _extra_repr from .skew import Skew @@ -186,7 +190,7 @@ def uniform_init_(tensor): x = torch.empty_like(tensor).normal_(0, 1) if transpose: x.transpose_(-2, -1) - q, r = torch.qr(x) + q, r = qr(x) # Make uniform (diag r >= 0) d = r.diagonal(dim1=-2, dim2=-1).sign() diff --git a/geotorch/stiefel.py b/geotorch/stiefel.py index 2d4b5884..a13cf367 100644 --- a/geotorch/stiefel.py +++ b/geotorch/stiefel.py @@ -1,5 +1,11 @@ import torch +try: + from torch.linalg import qr +except ImportError: + from torch import qr + + from .utils import transpose, _extra_repr from .so import SO, _has_orthonormal_columns @@ -60,7 +66,7 @@ def right_inverse(self, X, check_in_manifold=True): for _ in range(2): N = N - X @ (X.transpose(-2, -1) @ N) # And make it an orthonormal base of the image - N = N.qr().Q + N = qr(N).Q X = torch.cat([X, N], dim=-1) return super().right_inverse(X, check_in_manifold=False)[..., : self.k] diff --git a/geotorch/symmetric.py b/geotorch/symmetric.py index 840bd2ae..92662c49 100644 --- a/geotorch/symmetric.py +++ b/geotorch/symmetric.py @@ -1,5 +1,15 @@ import torch from torch import nn +from functools import partial +try: + from torch.linalg import eigh + from torch.linalg import eigvalsh +except ImportError: + from torch import symeig + eigh = partial(symeig, eigenvectors=True) + + def eigvalsh(X): + return symeig(X, eigenvectors=False).eigenvalues from .product import ProductManifold from .stiefel import Stiefel @@ -14,12 +24,12 @@ from .utils import _extra_repr -def _decreasing_symeig(X, eigenvectors): +def _decreasing_eigh(X, eigenvectors): if eigenvectors: - L, Q = X.symeig(eigenvectors=True) + L, Q = eigh(X) return L.flip(-1), Q.flip(-1) else: - return X.symeig().eigenvalues.flip(-1) + return eigvalsh(X).flip(-1) class Symmetric(nn.Module): @@ -148,7 +158,7 @@ def frame_inv(self, X1, X2): def submersion_inv(self, X, check_in_manifold=True): if isinstance(X, torch.Tensor): with torch.no_grad(): - L, Q = _decreasing_symeig(X, eigenvectors=True) + L, Q = _decreasing_eigh(X, eigenvectors=True) if check_in_manifold and not self.in_manifold_eigen(L): raise InManifoldError(X, self) else: @@ -211,7 +221,7 @@ def in_manifold(self, X, eps=1e-6): Args: X (torch.Tensor or tuple): The matrix to be checked or a tuple - ``(eigenvectors, eigenvalues)`` as returned by ``torch.symeig`` + ``(eigenvectors, eigenvalues)`` as returned by ``torch.linalg.eigh`` or ``self.sample(factorized=True)``. eps (float): Optional. Threshold at which the singular values are considered to be zero @@ -226,7 +236,7 @@ def in_manifold(self, X, eps=1e-6): if X.size() != size or not Symmetric.in_manifold(X, eps): return False - L = _decreasing_symeig(X, eigenvectors=False) + L = _decreasing_eigh(X, eigenvectors=False) return self.in_manifold_eigen(L, eps) def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True): @@ -268,7 +278,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True): ) init_(X) X = X @ X.transpose(-2, -1) - L, Q = _decreasing_symeig(X, eigenvectors=True) + L, Q = _decreasing_eigh(X, eigenvectors=True) L = L[..., : self.rank] Q = Q[..., : self.rank] if factorized: