From a7d2e26ef559487dce9ae7e838a95f87d9c8ad97 Mon Sep 17 00:00:00 2001 From: Neil Oxtoby Date: Thu, 22 Sep 2022 11:41:23 +0100 Subject: [PATCH] I think I fixed cross_validate_sustain_model() (#38) --- pySuStaIn/AbstractSustain.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pySuStaIn/AbstractSustain.py b/pySuStaIn/AbstractSustain.py index c2379db..b2aa2a1 100644 --- a/pySuStaIn/AbstractSustain.py +++ b/pySuStaIn/AbstractSustain.py @@ -264,16 +264,15 @@ def cross_validate_sustain_model(self, test_idxs, select_fold = [], plot=False): if select_fold != []: if np.isscalar(select_fold): select_fold = [select_fold] - Nfolds = len(select_fold) else: select_fold = np.arange(len(test_idxs)) #test_idxs - Nfolds = len(test_idxs) + Nfolds = len(select_fold) is_full = Nfolds == len(test_idxs) loglike_matrix = np.zeros((Nfolds, self.N_S_max)) - for fold in tqdm(range(Nfolds), "Folds: ", Nfolds, position=0, leave=True): + for fold in tqdm(select_fold, "Folds: ", Nfolds, position=0, leave=True): indx_test = test_idxs[fold] indx_train = np.array([x for x in range(self.__sustainData.getNumSamples()) if x not in indx_test])