Commit 733a888

removed reduced rank algorithm from main branch (see #54)
luigibonati committed May 2, 2023
1 parent 5be748c commit 733a888
Showing 1 changed file with 6 additions and 55 deletions.
61 changes: 6 additions & 55 deletions mlcolvar/core/stats/tica.py
@@ -41,7 +41,7 @@ def extra_repr(self) -> str:
         repr = f"in_features={self.in_features}, out_features={self.out_features}"
         return repr
 
-    def compute(self, data, weights = None, remove_average=True, save_params=False, algorithm = 'least_squares'):
+    def compute(self, data, weights = None, remove_average=True, save_params=False):
         """Perform TICA computation.
 
         Parameters
@@ -54,17 +54,12 @@ def compute(self, data, weights = None, remove_average=True, save_params=False,
             whether to make the inputs mean free, by default True
         save_params : bool, optional
             Save parameters of estimator, by default False
-        algorithm : str, optional
-            Algorithm to use, by default 'least_squares'. Options are 'least_squares' and 'reduced_rank'. Both algorithms are described in [1]_.
         Returns
         -------
         torch.Tensor
             Eigenvalues
         torch.Tensor,torch.Tensor
             eigenvalues,eigenvectors
-        References
-        ----------
-        .. [1] V. Kostic, P. Novelli, A. Maurer, C. Ciliberto, L. Rosasco, and M. Pontil, "Learning Dynamical Systems via Koopman Operator Regression in Reproducing Kernel Hilbert Spaces" (2022).
         """
         # parse args
 
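For reference, here is a minimal usage sketch of `compute()` after this change: callers simply drop the former `algorithm` keyword. The import path and the toy data are assumptions; the call pattern mirrors the test code further down in this diff.

```python
import torch
from mlcolvar.core.stats import TICA  # assumed import path

# toy trajectory: 100 frames with 10 features
X = torch.rand(100, 10)
x_t, x_lag = X[:-1], X[1:]            # time-lagged pairs (lag = 1 frame)
w_t = torch.ones(len(x_t))            # uniform weights, for illustration only
w_lag = w_t

tica = TICA(in_features=10, out_features=5)
# no 'algorithm' argument anymore: the least-squares estimator is always used
tica.compute([x_t, x_lag], [w_t, w_lag], save_params=True)
s = tica(X)                           # project onto the leading TICA modes
print(X.shape, '-->', s.shape)
```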
@@ -81,16 +76,7 @@ def compute(self, data, weights = None, remove_average=True, save_params=False,
         C_0 = correlation_matrix(x_t,x_t,w_t)
         C_lag = correlation_matrix(x_t,x_lag,w_lag)
 
-        if (algorithm == 'reduced_rank') and (self.out_features >= self.in_features):
-            warnings.warn('out_features is greater or equal than in_features. reduced_rank is equal to least_squares.')
-            algorithm = 'least_squares'
-
-        if algorithm == 'reduced_rank':
-            evals, evecs = reduced_rank_eig(C_0, C_lag, self.reg_C_0, rank = self.out_features)
-        elif algorithm != 'least_squares':
-            raise ValueError(f'algorithm {algorithm} not recognized. Options are least_squares and reduced_rank.')
-        else:
-            evals, evecs = cholesky_eigh(C_lag,C_0,self.reg_C_0,n_eig=self.out_features)
+        evals, evecs = cholesky_eigh(C_lag,C_0,self.reg_C_0,n_eig=self.out_features)
 
         if save_params:
             self.evals = evals
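The branch that remains always calls `cholesky_eigh(C_lag, C_0, self.reg_C_0, n_eig=...)`, i.e. it solves the symmetric generalized eigenvalue problem C_lag v = λ C_0 v with a regularized C_0. The sketch below illustrates that linear algebra in plain PyTorch; it is not the mlcolvar implementation, and the symmetrization, regularization, and ordering details are assumptions.

```python
import torch

def generalized_eigh_sketch(C_lag, C_0, reg=1e-6, n_eig=None):
    """Illustrative solve of C_lag v = lambda C_0 v via Cholesky whitening.
    NOT mlcolvar's cholesky_eigh; details here are assumptions."""
    n = C_0.shape[0]
    C_lag = 0.5 * (C_lag + C_lag.T)                # enforce a symmetric time-lagged correlation
    L = torch.linalg.cholesky(C_0 + reg * torch.eye(n, dtype=C_0.dtype))  # C_0 ~ L L^T
    L_inv = torch.linalg.inv(L)
    M = L_inv @ C_lag @ L_inv.T                    # whitened, symmetric matrix
    evals, U = torch.linalg.eigh(M)                # eigenvalues in ascending order
    evecs = L_inv.T @ U                            # map back: v = L^{-T} u
    idx = torch.argsort(evals, descending=True)    # slowest modes first
    if n_eig is not None:
        idx = idx[:n_eig]
    return evals[idx], evecs[:, idx]
```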
@@ -176,40 +162,5 @@ def test_tica():
     s2 = tica(X2)
     print(X2.shape,'-->',s2.shape)
 
-def test_reduced_rank_tica():
-    in_features = 10
-    X = torch.rand(100,in_features)*100
-    x_t = X[:-1]
-    x_lag = X[1:]
-    w_t = torch.rand(len(x_t))
-    w_lag = w_t
-
-    # direct way, compute tica function
-    tica = TICA(in_features,out_features=5)
-    print(tica)
-    tica.compute([x_t,x_lag],[w_t,w_lag], save_params=True, algorithm='reduced_rank')
-    s = tica(X)
-    print(X.shape,'-->',s.shape)
-    print('eigvals',tica.evals)
-    print('timescales', tica.timescales(lag=10))
-
-    # step by step
-    tica = TICA(in_features)
-    C_0 = correlation_matrix(x_t,x_t)
-    C_lag = correlation_matrix(x_t,x_lag)
-    print(C_0.shape,C_lag.shape)
-
-    evals, evecs = reduced_rank_eig(C_0, C_lag, 1e-6, rank = 5)
-    print(evals.shape,evecs.shape)
-
-    print('>> batch')
-    s = tica(X)
-    print(X.shape,'-->',s.shape)
-    print('>> single')
-    X2 = X[0]
-    s2 = tica(X2)
-    print(X2.shape,'-->',s2.shape)
-
 if __name__ == "__main__":
-    test_tica()
-    test_reduced_rank_tica()
+    test_tica()
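A side note on the test output that goes away here: the removed `test_reduced_rank_tica` printed `tica.timescales(lag=10)`. Implied timescales in TICA-style analyses follow the standard relation t_i = -lag / ln|λ_i|; the one-line sketch below is an assumption about that conversion, not code taken from this diff.

```python
import torch

def implied_timescales_sketch(evals: torch.Tensor, lag: float) -> torch.Tensor:
    # standard TICA relation: t_i = -lag / ln(|lambda_i|)
    return -lag / torch.log(torch.abs(evals))
```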
