Prepare for PyTorch 1.9
lezcano committed May 20, 2021
1 parent fe44ae2 commit b235311
Showing 8 changed files with 64 additions and 25 deletions.
examples/eigenvalue.py (12 changes: 10 additions & 2 deletions)
@@ -4,6 +4,14 @@
problem to the Sphere
"""
import torch
+try:
+    from torch.linalg import eigvalsh
+except ImportError:
+    from torch import symeig
+
+    def eigvalsh(X):
+        return symeig(X, eigenvectors=False).eigenvalues
+
from torch import nn
import geotorch

@@ -29,8 +37,8 @@ def forward(self, A):
A = torch.rand(N, N) # Uniform on [0, 1)
A = 0.5 * (A + A.T)

-# Compare against diagonalization
-max_eigenvalue = torch.symeig(A).eigenvalues.max()
+# Compare against diagonalization (eigenvalues are returned in ascending order)
+max_eigenvalue = eigvalsh(A)[-1]
print("Max eigenvalue: {:10.5f}".format(max_eigenvalue))

# Instantiate model and optimiser
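Note on the new import block in this example: it is a small compatibility shim, so the script runs both before and after torch.linalg landed. A minimal, self-contained sketch of the pattern (the 4x4 test matrix is only illustrative):

import torch

try:
    # torch.linalg.eigvalsh is available in recent PyTorch (1.8/1.9)
    from torch.linalg import eigvalsh
except ImportError:
    # Older releases only expose the deprecated torch.symeig
    from torch import symeig

    def eigvalsh(X):
        return symeig(X, eigenvectors=False).eigenvalues

A = torch.rand(4, 4)
A = 0.5 * (A + A.T)  # symmetrize, as the example does
# Both code paths return eigenvalues in ascending order,
# so the last entry is the largest eigenvalue.
print(eigvalsh(A)[-1])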
geotorch/constraints.py (14 changes: 7 additions & 7 deletions)
@@ -167,7 +167,7 @@ def almost_orthogonal(module, tensor_name="weight", lam=0.1, f="sin", triv="expm"
>>> layer = nn.Linear(20, 30)
>>> geotorch.almost_orthogonal(layer, "weight", 0.5)
->>> S = torch.svd(layer.weight).S
+>>> S = torch.linalg.svd(layer.weight).S
>>> all((S >= 0.5) & (S <= 1.5))
True
@@ -252,7 +252,7 @@ def low_rank(module, tensor_name, rank, triv="expm"):
>>> layer = nn.Linear(20, 30)
>>> geotorch.low_rank(layer, "weight", 4)
->>> list(torch.svd(layer.weight).S > 1e-7).count(True) <= 4
+>>> list(torch.linalg.svd(layer.weight).S > 1e-7).count(True) <= 4
True
Args:
@@ -284,7 +284,7 @@ def fixed_rank(module, tensor_name, rank, f="softplus", triv="expm"):
>>> layer = nn.Linear(20, 30)
>>> geotorch.fixed_rank(layer, "weight", 5)
->>> list(torch.svd(layer.weight).S > 1e-7).count(True)
+>>> list(torch.linalg.svd(layer.weight).S > 1e-7).count(True)
5
Args:
@@ -367,7 +367,7 @@ def positive_definite(module, tensor_name="weight", f="softplus", triv="expm"):
>>> layer = nn.Linear(20, 20)
>>> geotorch.positive_definite(layer, "weight")
->>> (torch.symeig(layer.weight).eigenvalues > 0.0).all()
+>>> (torch.linalg.eigvalsh(layer.weight) > 0.0).all()
tensor(True)
Args:
@@ -407,7 +407,7 @@ def positive_semidefinite(module, tensor_name="weight", triv="expm"):
>>> layer = nn.Linear(20, 20)
>>> geotorch.positive_semidefinite(layer, "weight")
->>> L = torch.symeig(layer.weight).eigenvalues
+>>> L = torch.linalg.eigvalsh(layer.weight)
>>> L[L.abs() < 1e-7] = 0.0 # Round errors
>>> (L >= 0.0).all()
tensor(True)
@@ -439,7 +439,7 @@ def positive_semidefinite_low_rank(module, tensor_name, rank, triv="expm"):
>>> layer = nn.Linear(20, 20)
>>> geotorch.positive_semidefinite_low_rank(layer, "weight", 5)
->>> L = torch.symeig(layer.weight).eigenvalues
+>>> L = torch.linalg.eigvalsh(layer.weight)
>>> L[L.abs() < 1e-7] = 0.0 # Round errors
>>> (L >= 0.0).all()
tensor(True)
@@ -478,7 +478,7 @@ def positive_semidefinite_fixed_rank(
>>> layer = nn.Linear(20, 20)
>>> geotorch.positive_semidefinite_fixed_rank(layer, "weight", 5)
->>> L = torch.symeig(layer.weight).eigenvalues
+>>> L = torch.linalg.eigvalsh(layer.weight)
>>> L[L.abs() < 1e-7] = 0.0 # Round errors
>>> (L >= 0.0).all()
tensor(True)
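The doctest updates in this file are mechanical: torch.svd becomes torch.linalg.svd and torch.symeig becomes torch.linalg.eigvalsh. As a sketch, the positive_definite doctest now runs like this (20x20 layer as in the docstring):

import torch
from torch import nn
import geotorch

layer = nn.Linear(20, 20)
geotorch.positive_definite(layer, "weight")
# torch.linalg.eigvalsh replaces the deprecated torch.symeig; every
# eigenvalue of the constrained weight should be strictly positive.
print((torch.linalg.eigvalsh(layer.weight) > 0.0).all())  # tensor(True)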
geotorch/fixedrank.py (2 changes: 1 addition & 1 deletion)
@@ -120,7 +120,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6):
"""
U, S, V = super().sample(factorized=True, init_=init_)
with torch.no_grad():
-# S >= 0, as given by torch.symeig()
+# S >= 0, as given by torch.linalg.eigvalsh()
S[S < eps] = eps
if factorized:
return U, S, V
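For context, the updated comment refers to the clamping step in sample(): the spectrum of a positive semidefinite factor is nonnegative up to round-off, and values below eps are pushed up to eps so the rank does not collapse. A standalone sketch of that step, with illustrative shapes:

import torch

eps = 5e-6
X = torch.randn(5, 5)
S = torch.linalg.eigvalsh(X @ X.T)  # PSD input, so S >= 0 up to round-off
S[S < eps] = eps  # clamp tiny values away from zero, as sample() does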
geotorch/lowrank.py (21 changes: 16 additions & 5 deletions)
@@ -1,4 +1,12 @@
import torch
+from functools import partial
+try:
+    from torch.linalg import svd
+    svd = partial(svd, full_matrices=False)
+except ImportError:
+    from torch import svd
+
+
from .product import ProductManifold
from .stiefel import Stiefel
from .reals import Rn
@@ -80,7 +88,7 @@ def frame_inv(self, X1, X2, X3):

def submersion_inv(self, X, check_in_manifold=True):
if isinstance(X, torch.Tensor):
-U, S, V = X.svd()
+U, S, V = svd(X)
if check_in_manifold and not self.in_manifold_singular_values(S):
raise InManifoldError(X, self)
else:
@@ -136,7 +144,7 @@ def in_manifold(self, X, eps=1e-5):
Args:
X (torch.Tensor or tuple): The matrix to be checked or a tuple containing
-:math:`(U, \Sigma, V)` as returned by ``torch.svd`` or
+:math:`(U, \Sigma, V)` as returned by ``torch.linalg.svd`` or
``self.sample(factorized=True)``.
eps (float): Optional. Threshold at which the singular values are
considered to be zero
@@ -152,7 +160,10 @@ def in_manifold(self, X, eps=1e-5):
X = X.transpose(-2, -1)
if X.size() != self.tensorial_size + (self.n, self.k):
return False
-_, S, _ = X.svd(compute_uv=False)
+try:
+    S = torch.linalg.svdvals(X)
+except AttributeError:
+    S = svd(X).S
return self.in_manifold_singular_values(S, eps)

def project(self, X, factorized=True):
@@ -170,7 +181,7 @@ def project(self, X, factorized=True):
used to initialize a parametrized tensor.
Default: ``True``
"""
-U, S, V = X.svd()
+U, S, V = svd(X)
U, S, V = U[..., : self.rank], S[..., : self.rank], V[..., : self.rank]
if factorized:
return U, S, V
@@ -211,7 +222,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True):
*(self.tensorial_size + (self.n, self.k)), device=device, dtype=dtype
)
init_(X)
-U, S, V = X.svd()
+U, S, V = svd(X)
U, S, V = U[..., : self.rank], S[..., : self.rank], V[..., : self.rank]
if factorized:
return U, S, V
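One caveat worth noting with this shim: torch.linalg.svd returns (U, S, Vh) with Vh the conjugate transpose of the V that torch.svd returned, so on new PyTorch the third value unpacked from svd(X) is Vh. A minimal sketch of the shim together with the svdvals fallback used in in_manifold (torch.linalg.svdvals only exists from PyTorch 1.9):

import torch
from functools import partial

try:
    from torch.linalg import svd

    # full_matrices=False reproduces the reduced factorization of torch.svd
    svd = partial(svd, full_matrices=False)
except ImportError:
    from torch import svd

X = torch.randn(6, 4)
U, S, V_or_Vh = svd(X)  # V on old PyTorch, Vh on new

try:
    S = torch.linalg.svdvals(X)  # singular values only, PyTorch >= 1.9
except AttributeError:
    S = svd(X).S  # both svd variants expose the .S field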
geotorch/pssdfixedrank.py (2 changes: 1 addition & 1 deletion)
@@ -101,7 +101,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6):
"""
L, Q = super().sample(factorized=True, init_=init_)
with torch.no_grad():
-# S >= 0, as given by torch.symeig()
+# S >= 0, as given by torch.linalg.eigvalsh()
small = L < eps
L[small] = eps
if factorized:
geotorch/so.py (6 changes: 5 additions & 1 deletion)
@@ -1,6 +1,10 @@
import math
import torch
from torch import nn
+try:
+    from torch.linalg import qr
+except ImportError:
+    from torch import qr

from .utils import _extra_repr
from .skew import Skew
@@ -186,7 +190,7 @@ def uniform_init_(tensor):
x = torch.empty_like(tensor).normal_(0, 1)
if transpose:
x.transpose_(-2, -1)
-q, r = torch.qr(x)
+q, r = qr(x)

# Make uniform (diag r >= 0)
d = r.diagonal(dim1=-2, dim2=-1).sign()
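For reference, the surrounding uniform_init_ implements the standard QR trick for sampling orthogonal matrices from the Haar (uniform) distribution: factor a Gaussian matrix and fix the signs so that diag(r) >= 0. A self-contained sketch of the idea, not the full uniform_init_ (which also handles transposition and writes in place):

import torch

try:
    from torch.linalg import qr
except ImportError:
    from torch import qr

x = torch.randn(4, 4)
q, r = qr(x)
# Scale each column of q by the sign of the matching diagonal entry of r.
# This makes the factorization unique and the resulting q Haar-uniform.
d = r.diagonal(dim1=-2, dim2=-1).sign()
q = q * d.unsqueeze(-2)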
geotorch/stiefel.py (8 changes: 7 additions & 1 deletion)
@@ -1,5 +1,11 @@
import torch

+try:
+    from torch.linalg import qr
+except ImportError:
+    from torch import qr


from .utils import transpose, _extra_repr
from .so import SO, _has_orthonormal_columns

@@ -60,7 +66,7 @@ def right_inverse(self, X, check_in_manifold=True):
for _ in range(2):
N = N - X @ (X.transpose(-2, -1) @ N)
# And make it an orthonormal base of the image
-N = N.qr().Q
+N = qr(N).Q
X = torch.cat([X, N], dim=-1)
return super().right_inverse(X, check_in_manifold=False)[..., : self.k]

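The hunk above sits inside right_inverse, which completes the k orthonormal columns of X to a full orthogonal basis: project a random block N off the span of X (twice, for numerical safety), then orthonormalize it with QR. A sketch under assumed shapes n=6, k=2:

import torch

try:
    from torch.linalg import qr
except ImportError:
    from torch import qr

n, k = 6, 2
X = qr(torch.randn(n, k))[0]  # stand-in for a point with orthonormal columns
N = torch.randn(n, n - k)
for _ in range(2):  # "twice is enough" reorthogonalization
    N = N - X @ (X.transpose(-2, -1) @ N)
N = qr(N).Q  # orthonormal basis of the orthogonal complement
X_full = torch.cat([X, N], dim=-1)  # an n x n orthogonal matrix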
geotorch/symmetric.py (24 changes: 17 additions & 7 deletions)
@@ -1,5 +1,15 @@
import torch
from torch import nn
+from functools import partial
+try:
+    from torch.linalg import eigh
+    from torch.linalg import eigvalsh
+except ImportError:
+    from torch import symeig
+    eigh = partial(symeig, eigenvectors=True)
+
+    def eigvalsh(X):
+        return symeig(X, eigenvectors=False).eigenvalues

from .product import ProductManifold
from .stiefel import Stiefel
@@ -14,12 +24,12 @@
from .utils import _extra_repr


-def _decreasing_symeig(X, eigenvectors):
+def _decreasing_eigh(X, eigenvectors):
if eigenvectors:
-L, Q = X.symeig(eigenvectors=True)
+L, Q = eigh(X)
return L.flip(-1), Q.flip(-1)
else:
-return X.symeig().eigenvalues.flip(-1)
+return eigvalsh(X).flip(-1)


class Symmetric(nn.Module):
@@ -148,7 +158,7 @@ def frame_inv(self, X1, X2):
def submersion_inv(self, X, check_in_manifold=True):
if isinstance(X, torch.Tensor):
with torch.no_grad():
-L, Q = _decreasing_symeig(X, eigenvectors=True)
+L, Q = _decreasing_eigh(X, eigenvectors=True)
if check_in_manifold and not self.in_manifold_eigen(L):
raise InManifoldError(X, self)
else:
@@ -211,7 +221,7 @@ def in_manifold(self, X, eps=1e-6):
Args:
X (torch.Tensor or tuple): The matrix to be checked or a tuple
-``(eigenvectors, eigenvalues)`` as returned by ``torch.symeig``
+``(eigenvectors, eigenvalues)`` as returned by ``torch.linalg.eigh``
or ``self.sample(factorized=True)``.
eps (float): Optional. Threshold at which the singular values are
considered to be zero
@@ -226,7 +236,7 @@ def in_manifold(self, X, eps=1e-6):
if X.size() != size or not Symmetric.in_manifold(X, eps):
return False

-L = _decreasing_symeig(X, eigenvectors=False)
+L = _decreasing_eigh(X, eigenvectors=False)
return self.in_manifold_eigen(L, eps)

def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True):
@@ -268,7 +278,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True):
)
init_(X)
X = X @ X.transpose(-2, -1)
-L, Q = _decreasing_symeig(X, eigenvectors=True)
+L, Q = _decreasing_eigh(X, eigenvectors=True)
L = L[..., : self.rank]
Q = Q[..., : self.rank]
if factorized:
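Taken together, the shim at the top of this file and the flip in _decreasing_eigh yield eigenvalues in decreasing order on any supported PyTorch version. A minimal sketch of both pieces:

import torch
from functools import partial

try:
    from torch.linalg import eigh, eigvalsh
except ImportError:
    from torch import symeig

    eigh = partial(symeig, eigenvectors=True)

    def eigvalsh(X):
        return symeig(X, eigenvectors=False).eigenvalues

A = torch.randn(4, 4)
A = A @ A.T  # symmetric PSD test matrix
L, Q = eigh(A)  # ascending eigenvalues on both code paths
L, Q = L.flip(-1), Q.flip(-1)  # decreasing, as _decreasing_eigh returns
print(eigvalsh(A).flip(-1))  # eigenvalues only, also decreasing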
