Moise/non sym toeplitz #83

Open
wants to merge 7 commits into main
134 changes: 113 additions & 21 deletions linear_operator/operators/toeplitz_linear_operator.py
@@ -1,77 +1,169 @@
#!/usr/bin/env python3
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

import torch
from jaxtyping import Float
from torch import Tensor

-from ..utils.toeplitz import sym_toeplitz_derivative_quadratic_form, sym_toeplitz_matmul
+from ..utils.errors import NotPSDError
+from ..utils.toeplitz import toeplitz_derivative_quadratic_form, toeplitz_matmul, toeplitz_solve_ld, toeplitz_inverse
from ._linear_operator import IndexType, LinearOperator


class ToeplitzLinearOperator(LinearOperator):
-def __init__(self, column):
+def __init__(self, column, row=None):
"""
Construct a Toeplitz matrix.
The Toeplitz matrix has constant diagonals, with `column` as its first
column and `row` as its first row. If `row` is not given,
`row == conjugate(column)` is assumed.

Args:
:attr: `column` (Tensor)
-If `column` is a 1D Tensor of length `n`, this represents a
+First column of the matrix. If `column` is a 1D Tensor of length `n`, this represents a
Toeplitz matrix with `column` as its first column.
If `column` is `b_1 x b_2 x ... x b_k x n`, then this represents a batch
`b_1 x b_2 x ... x b_k` of Toeplitz matrices.
:attr: `row` (Tensor)
First row of the matrix. If `row` is a 1D Tensor of length `n`, this represents a
Toeplitz matrix with `row` as its first row.
`row` tensor must have the same size as `column`, with `column[...,0]`
equal to `row[...,0]`.
If `row` is `None` or is not supplied, assumes `row == conjugate(column)`.
If `row[0]` is real, the result is a Hermitian matrix.
"""
-super(ToeplitzLinearOperator, self).__init__(column)
self.column = column
if row is None:
super(ToeplitzLinearOperator, self).__init__(column)
self.sym = True
myrow = column.conj()
myrow.data[...,0] = column[...,0]
self.row = myrow
else:
super(ToeplitzLinearOperator, self).__init__(column, row)
self.sym = False
self.row = row
if torch.any(row[...,0] != column[...,0]):
raise ValueError("The first element of `row` must match the first element of `column`.")
if torch.allclose(row, column.conj()):
self.sym = True
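
For reviewers, a minimal usage sketch of the extended constructor (values are illustrative):

import torch
from linear_operator.operators import ToeplitzLinearOperator

# Symmetric case: row defaults to conjugate(column).
col = torch.tensor([4.0, 1.0, 0.5])
T_sym = ToeplitzLinearOperator(col)

# Non-symmetric case: column and row must agree at index 0 (T[0, 0]).
row = torch.tensor([4.0, 2.0, 0.1])
T_gen = ToeplitzLinearOperator(col, row)
# T_gen.to_dense():
# [[4.0, 2.0, 0.1],
#  [1.0, 4.0, 2.0],
#  [0.5, 1.0, 4.0]]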

def _cholesky(
self: Float[LinearOperator, "*batch N N"], upper: Optional[bool] = False
) -> Float[LinearOperator, "*batch N N"]:
if not self.sym:
# Cholesky decomposition requires a Hermitian (symmetric) matrix
raise NotPSDError("Non-symmetric ToeplitzLinearOperator does not allow a Cholesky decomposition")
return super(ToeplitzLinearOperator, self)._cholesky(upper)

def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
diag_term = self.column[..., 0]
size = min(self.column.size(-1), self.row.size(-1))
if self.column.ndimension() > 1:
diag_term = diag_term.unsqueeze(-1)
-return diag_term.expand(*self.column.size())
+return diag_term.expand(*self.column.size()[:-1], size)

def _expand_batch(
self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]]
) -> Float[LinearOperator, "... M N"]:
-return self.__class__(self.column.expand(*batch_shape, self.column.size(-1)))
if self.sym:
return self.__class__(self.column.expand(*batch_shape, self.column.size(-1)))
else:
return self.__class__(
self.column.expand(*batch_shape, self.column.size(-1)),
self.row.expand(*batch_shape, self.row.size(-1)),
)

def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
toeplitz_indices = (row_index - col_index).fmod(self.size(-1)).abs().long()
-return self.column[(*batch_indices, toeplitz_indices)]
+res = torch.where(row_index > col_index, self.column[(*batch_indices, toeplitz_indices)], self.row[(*batch_indices, toeplitz_indices)])
+return res
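
For reference, a hand-rolled dense builder (an illustrative helper, not part of this PR) that realizes the same convention; whether the diagonal is read from column or row is immaterial since their first elements agree:

import torch

def dense_toeplitz(column, row):
    # T[i, j] = column[i - j] if i >= j else row[j - i]
    n = column.size(-1)
    i = torch.arange(n).unsqueeze(-1)
    j = torch.arange(n).unsqueeze(0)
    return torch.where(i >= j, column[(i - j).abs()], row[(j - i).abs()])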

def _matmul(
self: Float[LinearOperator, "*batch M N"],
rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
-return sym_toeplitz_matmul(self.column, rhs)
+return toeplitz_matmul(self.column, self.row, rhs)

def _t_matmul(
self: Float[LinearOperator, "*batch M N"],
rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
-# Matrix is symmetric
-return self._matmul(rhs)
+# The transpose of T(column, row) is T(row, column)
+return toeplitz_matmul(self.row, self.column, rhs)

def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]:
if left_vecs.ndimension() == 1:
left_vecs = left_vecs.unsqueeze(1)
right_vecs = right_vecs.unsqueeze(1)

-res = sym_toeplitz_derivative_quadratic_form(left_vecs, right_vecs)
+res_c, res_r = toeplitz_derivative_quadratic_form(left_vecs, right_vecs)

# Collapse any expanded broadcast dimensions
-if res.dim() > self.column.dim():
-res = res.view(-1, *self.column.shape).sum(0)
-return (res,)
if res_c.dim() > self.column.dim():
res_c = res_c.view(-1, *self.column.shape).sum(0)
if res_r.dim() > self.row.dim():
res_r = res_r.view(-1, *self.row.shape).sum(0)

res_r[...,0] = 0.  # zero the tied entry: the T[0,0] contribution is already counted in res_c[...,0]

if self.sym:
return (res_c + res_r,)
else:
return (res_c, res_r,)
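
Why res_r[..., 0] is zeroed: T is parameterized by (column, row) with row[..., 0] tied to column[..., 0], and both outputs of toeplitz_derivative_quadratic_form carry the T[0, 0] contribution in their 0th entry, so zeroing the row copy counts it exactly once. In the symmetric case row == column, and the chain rule for the tied parameterization gives

d/dc_i [u' T(c, c) v] = res_c[i] + res_r[i],

which is why the two pieces are summed into a single gradient.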

def _root_decomposition(
self: Float[LinearOperator, "... N N"]
) -> Union[Float[torch.Tensor, "... N N"], Float[LinearOperator, "... N N"]]:
if not self.sym:
raise NotPSDError("Non-symmetric ToeplitzLinearOperator does not allow a root decomposition")
return super(ToeplitzLinearOperator, self)._root_decomposition()

def _root_inv_decomposition(
self: Float[LinearOperator, "*batch N N"],
initial_vectors: Optional[torch.Tensor] = None,
test_vectors: Optional[torch.Tensor] = None,
) -> Union[Float[LinearOperator, "... N N"], Float[Tensor, "... N N"]]:
if not self.sym:
raise NotPSDError("Non-symmetric ToeplitzLinearOperator does not allow an inverse root decomposition")
return super(ToeplitzLinearOperator, self)._root_inv_decomposition(initial_vectors, test_vectors)

def _size(self) -> torch.Size:
-return torch.Size((*self.column.shape, self.column.size(-1)))
+return torch.Size((*self.row.shape, self.column.size(-1)))

def solve(
self: Float[LinearOperator, "... N N"],
right_tensor: Union[Float[Tensor, "... N P"], Float[Tensor, " N"]],
left_tensor: Optional[Float[Tensor, "... O N"]] = None,
) -> Union[Float[Tensor, "... N P"], Float[Tensor, "... N"], Float[Tensor, "... O P"], Float[Tensor, "... O"]]:
squeeze = False
if right_tensor.dim() == 1:
rhs_ = right_tensor.unsqueeze(-1)
squeeze = True
else:
rhs_ = right_tensor
res = toeplitz_solve_ld(self.column, self.row, rhs_)
if squeeze:
res = res.squeeze(-1)
if left_tensor is not None:
res = left_tensor @ res
return res
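
A short end-to-end sketch of the new solve path (illustrative values; assumes a well-conditioned system, which Levinson-Durbin requires):

import torch
from linear_operator.operators import ToeplitzLinearOperator

col = torch.tensor([5.0, 1.0, 0.3])
row = torch.tensor([5.0, 0.7, 0.2])
T = ToeplitzLinearOperator(col, row)
b = torch.randn(3, 2)

x = T.solve(b)  # Levinson-Durbin solve, O(n^2)
assert torch.allclose(T.to_dense() @ x, b, atol=1e-4)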

def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:
-return ToeplitzLinearOperator(self.column)
if self.sym:
return ToeplitzLinearOperator(self.column)
else:
# The transpose's first column is the original first row (they share T[0,0]),
# and its first row is the original first column.
new_column = torch.cat([self.column[...,0].unsqueeze(-1), self.row[...,1:]], dim=-1)
new_row = torch.clone(self.column)
return ToeplitzLinearOperator(new_column, new_row)

def add_jitter(
self: Float[LinearOperator, "*batch N N"], jitter_val: float = 1e-3
) -> Float[LinearOperator, "*batch N N"]:
jitter = torch.zeros_like(self.column)
jitter.narrow(-1, 0, 1).fill_(jitter_val)
-return ToeplitzLinearOperator(self.column.add(jitter))
if self.sym:
return ToeplitzLinearOperator(self.column.add(jitter))
else:
return ToeplitzLinearOperator(self.column.add(jitter), self.row.add(jitter))
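
Since every diagonal entry of a Toeplitz matrix equals column[..., 0], adding jitter_val to the first element of column (and of row, preserving the row[..., 0] == column[..., 0] invariant) shifts the whole diagonal uniformly. Continuing the sketch above:

T2 = T.add_jitter(1e-3)
# The diagonal shifts uniformly:
# T2.to_dense().diagonal() == T.to_dense().diagonal() + 1e-3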
1 change: 0 additions & 1 deletion linear_operator/test/linear_operator_test_case.py
@@ -664,7 +664,6 @@ def test_add_jitter(self):
self.assertAllClose(res, actual)

def test_add_low_rank(self):
-linear_op = self.create_linear_op()
linear_op = self.create_linear_op()
evaluated = self.evaluate_linear_op(linear_op)
new_rows = torch.randn(*linear_op.shape[:-1], 3)
162 changes: 160 additions & 2 deletions linear_operator/utils/toeplitz.py
@@ -104,8 +104,7 @@ def toeplitz_matmul(toeplitz_column, toeplitz_row, tensor):
Returns:
- tensor (n x p or b x n x p) - The result of the matrix multiply T * M.
"""
-if toeplitz_column.size() != toeplitz_row.size():
-raise RuntimeError("c and r should have the same length (Toeplitz matrices are necessarily square).")
+toeplitz_input_check(toeplitz_column, toeplitz_row)

toeplitz_shape = torch.Size((*toeplitz_column.shape, toeplitz_row.size(-1)))
output_shape = broadcasting._matmul_broadcast_shape(toeplitz_shape, tensor.shape)
@@ -160,6 +159,104 @@ def sym_toeplitz_matmul(toeplitz_column, tensor):
return toeplitz_matmul(toeplitz_column, toeplitz_column, tensor)


def sym_toeplitz_solve_ld(toeplitz_column, right_vectors):
"""
Solve the linear system Tx=b where T is a symmetric Toeplitz matrix and b the right
hand side of the equation, using the Levinson-Durbin recursion, which runs in O(n^2) time.
Args:
- toeplitz_column (vector n or b x n) - First column of the symmetric Toeplitz matrix T.
- right_vectors (matrix n x p or b x n x p) - Right hand side in T x = b
Returns:
- tensor (n x p or b x n x p) - The solution to the system T x = b.
Shape of return matches shape of b.
"""
return toeplitz_solve_ld(toeplitz_column, toeplitz_column, right_vectors)


def toeplitz_solve_ld(toeplitz_column, toeplitz_row, right_vectors):
"""
Solve the linear system Tx=b where T is a general Toeplitz matrix and b the right
hand side of the equation, using the Levinson-Durbin recursion, which runs in O(n^2) time
but may exhibit numerical stability issues; it requires every leading principal minor of T to be nonsingular.
Args:
- toeplitz_column (vector n or b x n) - First column of the Toeplitz matrix T.
- toeplitz_row (vector n or b x n) - First row of the Toeplitz matrix T.
- right_vectors (matrix n x p or b x n x p) - Right hand side in T x = b
Returns:
- tensor (n x p or b x n x p) - The solution to the system T x = b.
Shape of return matches shape of b.
"""
# check input
toeplitz_input_check(toeplitz_column, toeplitz_row)
if right_vectors.ndimension() == 1:
if toeplitz_row.shape[-1] != len(right_vectors):
raise RuntimeError(f"Incompatible sizes between the Toeplitz matrix and the right-hand side: {toeplitz_column.shape} and {right_vectors.shape}")
else:
if toeplitz_row.shape[-1] != right_vectors.size(-2):
raise RuntimeError(f"Incompatible sizes between the Toeplitz matrix and the right-hand side: {toeplitz_column.shape} and {right_vectors.shape}")

output_shape = torch.broadcast_shapes(toeplitz_row.shape, right_vectors.shape[:-1])
broadcasted_t_shape = output_shape
unsqueezed_vec = False
if right_vectors.ndimension() == 1:
right_vectors = right_vectors.unsqueeze(-1)
unsqueezed_vec = True
toeplitz_column = toeplitz_column.expand(*broadcasted_t_shape)
toeplitz_row = toeplitz_row.expand(*broadcasted_t_shape)
N = toeplitz_column.size(-1)

# f = forward vector, b = backward vector
# suffix "i" = value at iteration i, suffix "im" = value at iteration i-1 (e.g. bim)
flipped_toeplitz_column = toeplitz_column[..., 1:].flip(dims=(-1,))
xi = torch.zeros_like(right_vectors).expand(*broadcasted_t_shape, right_vectors.shape[-1]).clone()
fi = torch.zeros_like(xi)
bi = torch.zeros_like(xi)
bim = torch.zeros_like(xi)

# iteration 0
fi[...,0,:] = 1/toeplitz_column[...,0,None]
bi[...,N-1,:] = 1/toeplitz_column[...,0,None]
xi[...,0,:] = right_vectors[...,0,:]/toeplitz_column[...,0,None]
if N == 1: return xi

for i in range(1,N):
# save the backward vector from the previous iteration
bim = bi.clone()
# compute the new forward and backward vectors
efi = torch.matmul(flipped_toeplitz_column[...,N-i-1:N-1,None].mT, fi.clone()[...,:i,:])
ebi = torch.matmul(toeplitz_row[...,1:i+1,None].mT, bim[...,N-i:,:])
coeff = 1/(1-ebi*efi)
bi[...,N-i-1:,:] = coeff * (bim[...,N-i-1:,:] - ebi * fi.clone()[...,:i+1,:])
fi[...,:i+1,:] = coeff * (fi[...,:i+1,:] - efi * bim[...,N-i-1:,:])
# extend the solution with the new backward vector
exim = torch.matmul(flipped_toeplitz_column[...,N-i-1:N-1,None].mT, xi.clone()[...,:i,:])
xi[...,:i+1,:] = xi[...,:i+1,:] + bi.clone()[...,N-i-1:,:] * (right_vectors[...,i,:,None].mT - exim)

if unsqueezed_vec:
xi = xi.squeeze(-1)

return xi
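
A sketch of how this PR's toeplitz_solve_ld can be sanity-checked against a dense solve (assumes a diagonally dominant matrix so that every leading principal minor is nonsingular):

import torch
from linear_operator.utils.toeplitz import toeplitz_solve_ld

n, p = 6, 3
c = torch.randn(n); c[0] = 2.0 * n  # make T diagonally dominant
r = torch.randn(n); r[0] = c[0]
b = torch.randn(n, p)

# Dense reference: T[i, j] = c[i - j] if i >= j else r[j - i]
i = torch.arange(n).unsqueeze(-1)
j = torch.arange(n).unsqueeze(0)
T = torch.where(i >= j, c[(i - j).abs()], r[(j - i).abs()])

x = toeplitz_solve_ld(c, r, b)
assert torch.allclose(T @ x, b, atol=1e-4)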


def toeplitz_inverse(toeplitz_column, toeplitz_row):
"""
Calculate the Toeplitz matrix inverse following the Trench algorithm.
(See: Shalhav Zohar (1969) - Toeplitz Matrix Inversion: The Algorithm of W. F. Trench)
Args:
- toeplitz_column (vector n or b x n) - First column of the Toeplitz matrix T.
- toeplitz_row (vector n or b x n) - First row of the Toeplitz matrix T.
Returns:
- tensor (n x n or b x n x n) - The inverse of the Toeplitz matrix.
"""
# Algorithm taken from:
# https://dl.acm.org/doi/pdf/10.1145/321541.321549
if torch.any(toeplitz_column[..., 0] == 0.):
raise ValueError("The main diagonal term (i.e. the first element of column and row) must be non-zero")
raise NotImplementedError("To be implemented")
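
Until the Trench recursion lands, one possible stopgap (my sketch, not part of this PR) is to form the inverse column-by-column with the Levinson-Durbin solver above; this costs O(n^3) versus Trench's O(n^2), so it only makes sense as a reference implementation for tests:

import torch

def toeplitz_inverse_reference(toeplitz_column, toeplitz_row):
    # Columns of T^{-1} solve T x = e_k; toeplitz_solve_ld is defined above in this module.
    n = toeplitz_column.size(-1)
    identity = torch.eye(n, dtype=toeplitz_column.dtype, device=toeplitz_column.device)
    return toeplitz_solve_ld(toeplitz_column, toeplitz_row, identity)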


def sym_toeplitz_derivative_quadratic_form(left_vectors, right_vectors):
r"""
Given a left vector v1 and a right vector v2, computes the quadratic form:
@@ -201,3 +298,64 @@ def sym_toeplitz_derivative_quadratic_form(left_vectors, right_vectors):
res[..., 0] -= (left_vectors * right_vectors).view(*batch_shape, -1).sum(-1)

return res


def toeplitz_derivative_quadratic_form(left_vectors, right_vectors):
r"""
Given left vectors u and right vectors v, computes the quadratic forms:
u'*(dT/dc_i)*v and u'*(dT/dr_i)*v
for all i, where dT/dc_i (resp. dT/dr_i) is the derivative of the Toeplitz matrix with
respect to the ith element of its first column (resp. first row). Note that these
derivatives are the same for any Toeplitz matrix T, so we do not require T as an argument.

In particular, dT/dc_i is the matrix with ones on the ith subdiagonal, and dT/dr_i is
the matrix with ones on the ith superdiagonal.

Args:
- left_vectors (vector m or matrix m x s) - s left vectors u[j] in the quadratic forms.
- right_vectors (vector m or matrix m x s) - s right vectors v[j] in the quadratic forms.
Returns:
- vector d_column - the ith element is \sum_j(u[j]'*(dT/dc_i)*v[j])
(derivatives with respect to the entries of the first column)
- vector d_row - the ith element is \sum_j(u[j]'*(dT/dr_i)*v[j])
(derivatives with respect to the entries of the first row)
"""
if left_vectors.ndimension() == 1:
left_vectors = left_vectors.unsqueeze(1)
right_vectors = right_vectors.unsqueeze(1)

batch_shape = left_vectors.shape[:-2]
toeplitz_size = left_vectors.size(-2)
num_vectors = left_vectors.size(-1)

left_vectors = left_vectors.mT.contiguous()
right_vectors = right_vectors.mT.contiguous()

columns = torch.zeros_like(left_vectors)
columns[..., 0] = left_vectors[..., 0]

res_r = toeplitz_matmul(columns, left_vectors, right_vectors.unsqueeze(-1))
rows = left_vectors.flip(dims=(-1,))
columns[..., 0] = rows[..., 0]
res_c = toeplitz_matmul(columns, rows, torch.flip(right_vectors, dims=(-1,)).unsqueeze(-1))

res_c = res_c.reshape(*batch_shape, num_vectors, toeplitz_size).sum(-2)
res_r = res_r.reshape(*batch_shape, num_vectors, toeplitz_size).sum(-2)

return res_c, res_r
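
A way to validate these closed-form gradients against autograd (illustrative sizes; note that r.grad[0] stays zero in the dense reconstruction below because T[0, 0] is taken from c, mirroring the res_r[..., 0] = 0 convention in _bilinear_derivative):

import torch
from linear_operator.utils.toeplitz import toeplitz_derivative_quadratic_form

n, s = 5, 2
u, v = torch.randn(n, s), torch.randn(n, s)
c = torch.randn(n, requires_grad=True)
r = torch.randn(n, requires_grad=True)

i = torch.arange(n).unsqueeze(-1)
j = torch.arange(n).unsqueeze(0)
T = torch.where(i >= j, c[(i - j).abs()], r[(j - i).abs()])
(u * (T @ v)).sum().backward()  # sum_j u[:, j]' T v[:, j]

d_c, d_r = toeplitz_derivative_quadratic_form(u, v)
assert torch.allclose(c.grad, d_c, atol=1e-4)
assert torch.allclose(r.grad[1:], d_r[1:], atol=1e-4)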


def toeplitz_input_check(toeplitz_column, toeplitz_row):
"""
Helper routine to check if the input Toeplitz matrix is well defined.
"""
if toeplitz_column.size() != toeplitz_row.size():
raise RuntimeError("c and r should have the same length (Toeplitz matrices are necessarily square).")
if not torch.equal(toeplitz_column[..., 0], toeplitz_row[..., 0]):
raise RuntimeError(
"The first column and first row of the Toeplitz matrix should have "
"the same first element, otherwise the value of T[0,0] is ambiguous. "
"Got: c[0]={} and r[0]={}".format(toeplitz_column[..., 0], toeplitz_row[..., 0])
)
if type(toeplitz_column) != type(toeplitz_row):
raise RuntimeError("toeplitz_column and toeplitz_row should be the same type.")
return True