From e608ac553a35caca06e6e6a471580a5b148fb12b Mon Sep 17 00:00:00 2001 From: Chenggong Date: Thu, 20 Jul 2023 10:56:07 +0200 Subject: [PATCH] new rate from pyemma --- Sfilter/util/MSM.py | 157 +++++++++++++++++++++++++++++++------------- test/test_MSM.py | 22 ++++++- 2 files changed, 132 insertions(+), 47 deletions(-) diff --git a/Sfilter/util/MSM.py b/Sfilter/util/MSM.py index c762248..b3de2b0 100644 --- a/Sfilter/util/MSM.py +++ b/Sfilter/util/MSM.py @@ -5,7 +5,7 @@ import networkx as nx import matplotlib.pyplot as plt from .output_wrapper import read_k_cylinder - +import pyemma @@ -171,17 +171,39 @@ def calc_state_array(self, merge_list=None): def get_transition_matrix(self, lag_step=1): """ - compute transition matrix + compute states transition matrix. a numpy array, each element is the number of transition from state i to state j """"" state_num = max([max(traj) for traj in self.state_array]) - tran_matrix = np.zeros((state_num + 1, state_num + 1), dtype=np.int64) + f_matrix = np.zeros((state_num + 1, state_num + 1), dtype=np.int64) + #for traj in self.state_array: + # state_start = traj[:-lag_step] + # state_end = traj[lag_step:] + # for m_step in np.array([state_start, state_end]).T: + # tran_matrix[m_step[0], m_step[1]] += 1 for traj in self.state_array: state_start = traj[:-lag_step] state_end = traj[lag_step:] for m_step in np.array([state_start, state_end]).T: - tran_matrix[m_step[0], m_step[1]] += 1 - return tran_matrix + f_matrix[m_step[0], m_step[1]] += 1 + + return f_matrix + + def f_matrix_2_rate_matrix(self, f_matrix, physical_time): + """ + compute rate matrix + a numpy array, each element is the rate from state i to state j + input: + f_matrix: a numpy array, each element is the number of transition from state i to state j + physical_time: physical time for each step + return: + rate_matrix: a numpy array, each element is the rate from state i to state j + """ + rate_matrix = np.array(f_matrix, dtype=np.float64) # convert int to float + for i in range(rate_matrix.shape[0]): + rate_matrix[i, :] /= self.node_counter[i] * physical_time + rate_matrix[i, i] = 0 + return rate_matrix def get_rate_matrix(self, lag_step=1, physical_time=None): """ @@ -200,25 +222,31 @@ def get_rate_matrix(self, lag_step=1, physical_time=None): else: raise ValueError("physical_time is not given, and time_step is not equal") - t_matrix = self.get_transition_matrix(lag_step) - rate_matrix = np.array(t_matrix, dtype=np.float64) - for i in range(rate_matrix.shape[0]): - rate_matrix[i, :] /= self.node_counter[i] * physical_time - rate_matrix[i, i] = 0 + f_matrix = self.get_transition_matrix(lag_step) + rate_matrix = self.f_matrix_2_rate_matrix(f_matrix, physical_time) return rate_matrix - def get_transition_probability(self, lag_step=1): + def f_matrix_2_transition_probability(self, f_matrix): """ - compute transition probability + compute transition probability matrix (between steps) return: transition_probability_matrix a numpy array, each element is the probability of transition from state i to state j The sum of each row is 1. """ - t_matrix = self.get_transition_matrix(lag_step) - p_matrix = np.array(t_matrix, dtype=np.float64) + p_matrix = np.array(f_matrix, dtype=np.float64) # int to float p_matrix /= p_matrix.sum(axis=1, keepdims=True) # normalize each row return p_matrix + def get_transition_probability(self, lag_step=1): + """ + compute transition probability matrix (between steps) + return: transition_probability_matrix + a numpy array, each element is the probability of transition from state i to state j + The sum of each row is 1. + """ + f_matrix = self.get_transition_matrix(lag_step) + return self.f_matrix_2_transition_probability(f_matrix) + def get_CK_test(self, lag_step=1, test_time=[2, 4]): """ run Chapman-Kolmogorov test @@ -303,8 +331,8 @@ def get_matrix(self, lag_step=1, physical_time=None): """ calculate transition matrix, rate_matrix, and transition probability return: - t_matrix: each element is the number of transition from state i to state j - net_t_matrix: each element is the number of net event between state i to state j (t_matrix - t_matrix.T) + f_matrix: each element is the number of flux from state i to state j + net_f_matrix: each element is the number of net event between state i to state j (f_ij - f_ji) rate_matrix: each element is the rate (number of event / observation time) from state i to state j p_matrix: each element is the probability of transition from state i to state j input: @@ -317,15 +345,11 @@ def get_matrix(self, lag_step=1, physical_time=None): else: raise ValueError("physical_time is not given, and time_step is not equal") - t_matrix = self.get_transition_matrix(lag_step) - net_t_matrix = t_matrix - t_matrix.T - rate_matrix = np.array(t_matrix, dtype=np.float64) - for i in range(rate_matrix.shape[0]): - rate_matrix[i, :] /= self.node_counter[i] * physical_time - rate_matrix[i, i] = 0 - p_matrix = np.array(t_matrix, dtype=np.float64) - p_matrix /= p_matrix.sum(axis=1, keepdims=True) - return t_matrix, net_t_matrix, rate_matrix, p_matrix + f_matrix = self.get_transition_matrix(lag_step) # transition between states (not steps) + net_t_matrix = f_matrix - f_matrix.T + rate_matrix = self.f_matrix_2_rate_matrix(f_matrix, physical_time) + p_matrix = self.f_matrix_2_transition_probability(f_matrix) + return f_matrix, net_t_matrix, rate_matrix, p_matrix def get_resident_time(self): """ @@ -356,7 +380,7 @@ def find_merge_states(self, cut_off=0.01, lag_step=1, physical_time=None, method """ if physical_time is None: # use the time_step that was read from file if np.allclose(self.time_step, self.time_step[0]): - physical_time = self.time_step[0] + physical_time = self.time_step[0] * lag_step else: raise ValueError("physical_time is not given, and time_step is not equal") @@ -446,6 +470,67 @@ def lump_MFPT(self, node_cut_off=0.01, min_node=3): raise ValueError("merge error " + str(n_0) + str(n_1)) return True, merge_list_new, [self.int_2_s[n_0], self.int_2_s[n_1]] + def get_pyemma_TPT_rate(self): + rate_matrix = np.zeros((len(self.int_2_s), len(self.int_2_s))) + msm = pyemma.msm.estimate_markov_model(self.state_array, lag=1, + reversible=False, dt_traj=str(self.time_step[0])+" ps") + for i in range(len(self.int_2_s)): + for j in range(len(self.int_2_s)): + if i != j: + rate_matrix[i, j] = pyemma.msm.tpt(msm, [i], [j]).rate + else: + rate_matrix[i, j] = 0 + return rate_matrix + + + def lump_pyemma_TPT_rate(self, node_cut_off=0.01, min_node=3): + total_count = self.node_counter.total() + for n, node_count in self.node_counter.items(): + if node_count / total_count < node_cut_off: + break + # check node number + if n < min_node: + return False, copy.deepcopy(self.merge_list), [0, 0] + # get rate matrix using pyemma TPT + msm_pyemma = pyemma.msm.estimate_markov_model(self.state_array, lag=1, + reversible=False, dt_traj=str(self.time_step[0])+" ps") + rate_matrix = np.zeros((n, n)) + for i in range(n): + for j in range(n): + if i != j: + rate_matrix[i, j] = pyemma.msm.tpt(msm_pyemma, [i], [j]).rate + else: + rate_matrix[i, j] = 0 + rate_list = [] + for i in range(n): + for j in range(0, i): + rate_list.append([i, j, rate_matrix[i, j] * rate_matrix[j, i]]) + rate_list = sorted(rate_list, key=lambda x: x[2], reverse=True) + n_0, n_1, rate = rate_list[0] + if len(self.int_2_s[n_0]) > 1 and len(self.int_2_s[n_1]) > 1: + merge_list_new = copy.deepcopy(self.merge_list) + for node in self.merge_list: + for node2 in self.merge_list: + if node == self.int_2_s[n_0] and node2 == self.int_2_s[n_1]: + merge_list_new.remove(node) + merge_list_new.remove(node2) + merge_list_new.append(node + node2) + elif len(self.int_2_s[n_0]) > 1 or len(self.int_2_s[n_1]) > 1: + merge_list_new = [] + for node in self.merge_list: + if node == self.int_2_s[n_0]: + merge_list_new.append(node + self.int_2_s[n_1]) + elif node == self.int_2_s[n_1]: + merge_list_new.append(self.int_2_s[n_0] + node) + else: + merge_list_new.append(node) + elif len(self.int_2_s[n_0]) == 1 or len(self.int_2_s[n_1]) == 1: + merge_list_new = copy.deepcopy(self.merge_list) + merge_list_new.append(self.int_2_s[n_0] + self.int_2_s[n_1]) + else: + raise ValueError("merge error " + str(n_0) + str(n_1)) + return True, merge_list_new, [self.int_2_s[n_0], self.int_2_s[n_1]] + def merge_until(self, rate_cut_off, rate_square_cut_off, node_cut_off=0.01, step_cut_off=30, lag_step=1, physical_time=None, @@ -635,25 +720,5 @@ def computer_pos(strings): return x, y -def get_transition_matrix(state_arrays, begin=0, lag_time=1): - state_num = max([max(traj) for traj in state_arrays]) - tran_matrix = np.zeros((state_num + 1, state_num + 1), dtype=np.int64) - for traj in state_arrays: - state_start = traj[begin:-lag_time] - state_end = traj[begin + lag_time:] - for m_step in np.array([state_start, state_end]).T: - tran_matrix[m_step[0], m_step[1]] += 1 - return tran_matrix - - -def get_distribution(state_arrays): - flattened = [num for sublist in state_arrays for num in sublist] - return Counter(flattened) -def get_rate_matrix(state_arrays, phy_time, begin=0, lag_time=1): - tran_matrix = np.array(get_transition_matrix(state_arrays, begin, lag_time), dtype=np.float64) - counter = get_distribution(state_arrays) - for i in range(tran_matrix.shape[0]): - tran_matrix[i, :] /= counter[i] * phy_time - return tran_matrix diff --git a/test/test_MSM.py b/test/test_MSM.py index 58190ab..d437b90 100644 --- a/test/test_MSM.py +++ b/test/test_MSM.py @@ -51,6 +51,26 @@ def test_SF_msm_set_state_str(self): self.assertDictEqual(msm.state_counter, {"A": 11, "B": 7, "C": 3}) self.assertDictEqual(msm.node_counter, {0: 18, 1: 3}) + def test_SF_msm_get_transition_matrix(self): + msm = MSM.SF_msm([]) + msm.set_state_str(["A A B C A B C A B C A B C A B".split()]) + msm.calc_state_array() + f_matrix_1 = msm.get_transition_matrix(lag_step=1) + f_matrix_2 = msm.get_transition_matrix(lag_step=2) + f_matrix_3 = msm.get_transition_matrix(lag_step=3) + self.assertListEqual(f_matrix_1.tolist(), [[1, 5, 0], [0, 0, 4], [4, 0, 0]]) + self.assertListEqual(f_matrix_2.tolist(), [[0, 1, 4], [4, 0, 0], [0, 4, 0]]) + self.assertListEqual(f_matrix_3.tolist(), [[4, 0, 1], [0, 4, 0], [0, 0, 3]]) + msm.time_step = [1] + r_1 = msm.get_rate_matrix(lag_step=1) + r_2 = msm.get_rate_matrix(lag_step=2) + r_3 = msm.get_rate_matrix(lag_step=3) + self.assertListEqual(r_1.tolist(), [[0, 5/6, 0], [0, 0, 4/5], [4/4, 0, 0]]) + self.assertListEqual(r_2.tolist(), [[0, 1/6, 4/6], [4/5, 0, 0], [0, 4/4, 0]]) + self.assertListEqual(r_3.tolist(), [[0, 0, 1/6], [0, 0, 0], [0, 0, 0]]) + + + def test_SF_msm_get_matrix(self): msm = MSM.SF_msm([]) msm.set_state_str(["A B A B A B C D".split(), @@ -145,7 +165,7 @@ def test_SF_msm_merge_until_03(self): "A B A B A B A B A B A D C D D C C D D".split(), "B A B A B A B A B A B D D C D C C D D E D".split()]) msm.calc_state_array() - reason = msm.merge_until(rate_cut_off=0.00, rate_square_cut_off=0.00, node_cut_off=0.017, lag_step=1, physical_time=1, method="rate_square", min_node=1) + reason = msm.merge_until(rate_cut_off=0.00, rate_square_cut_off=0.00, node_cut_off=0.017, lag_step=1, physical_time=1, method="rate_square", min_node=2) t_matrix, net_t_matrix, rate_matrix, p_matrix = msm.get_matrix(lag_step=1, physical_time=1) print() print(reason)