diff --git a/pySuStaIn/AbstractSustain.py b/pySuStaIn/AbstractSustain.py index c2379db..b2aa2a1 100644 --- a/pySuStaIn/AbstractSustain.py +++ b/pySuStaIn/AbstractSustain.py @@ -264,16 +264,15 @@ def cross_validate_sustain_model(self, test_idxs, select_fold = [], plot=False): if select_fold != []: if np.isscalar(select_fold): select_fold = [select_fold] - Nfolds = len(select_fold) else: select_fold = np.arange(len(test_idxs)) #test_idxs - Nfolds = len(test_idxs) + Nfolds = len(select_fold) is_full = Nfolds == len(test_idxs) loglike_matrix = np.zeros((Nfolds, self.N_S_max)) - for fold in tqdm(range(Nfolds), "Folds: ", Nfolds, position=0, leave=True): + for fold in tqdm(select_fold, "Folds: ", Nfolds, position=0, leave=True): indx_test = test_idxs[fold] indx_train = np.array([x for x in range(self.__sustainData.getNumSamples()) if x not in indx_test])