
IndexError during CV #59

Open
katrinaCode opened this issue Aug 8, 2024 · 1 comment

@katrinaCode

Hi all,

I wanted to submit a fix for an occasional error I get during cross-validation (CV).
The error is as follows:

Traceback (most recent call last):
  File "", line 176, in <module>
    CVIC, loglike_matrix     = sustain_input.cross_validate_sustain_model(test_idxs)
  File "AbstractSustain.py", line 294, in cross_validate_sustain_model
    sustainData_test                = self.__sustainData.reindex(indx_test)
  File "ZscoreSustain.py", line 54, in reindex
    return ZScoreSustainData(self.data[index,], self.__numStages)
IndexError: arrays used as indices must be of integer (or boolean) type
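A minimal reproduction, assuming only NumPy: `reindex` in ZScoreSustain.py indexes `self.data[index,]`, and NumPy rejects an index array whose dtype is `object` even when every element is an integer. The `data` matrix and index values below are hypothetical stand-ins.

```python
import numpy as np

# Hypothetical stand-in for the data matrix held by ZScoreSustainData
data = np.arange(12).reshape(6, 2)

# Index values are all integers, but the array's dtype is object
idx_object = np.array([0, 2, 4], dtype=object)

try:
    data[idx_object,]  # same indexing pattern as reindex() in ZScoreSustain.py
except IndexError as err:
    print(f"IndexError: {err}")
```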

And my fix is simply to explicitly define indx_test as an array of integers at line 277 of AbstractSustain.py:

indx_test                       = (test_idxs[fold]).astype(int)
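A minimal sketch of the proposed fix in isolation, using a hypothetical `data` matrix: `.astype(int)` turns an object-dtype index into an integer array, and it leaves an index that is already integer-typed working as before, so the patched line should be safe whichever dtype `test_idxs[fold]` arrives with.

```python
import numpy as np

# Hypothetical stand-in for the data matrix that reindex() slices
data = np.arange(12).reshape(6, 2)

idx_object = np.array([0, 2, 4], dtype=object)  # integer values, object dtype (fails as an index)
idx_int = np.array([0, 2, 4])                   # already an integer array (works as an index)

for idx in (idx_object, idx_int):
    # The proposed fix at line 277: force an integer dtype before indexing
    indx_test = idx.astype(int)
    print(indx_test.dtype, data[indx_test,].tolist())
```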

I'd be interested to hear any theories as to why this error happens only intermittently: it occurs with some models but not others, even when running identical versions of my notebook. In that section of my notebook, I follow the SuStaIn workshop essentially verbatim:

import numpy as np
import sklearn.model_selection

labels = sustain_data[label_column].values
cv = sklearn.model_selection.StratifiedKFold(n_splits=N_folds, shuffle=True, random_state=3)
cv_it = cv.split(sustain_data, labels)

# SuStaIn currently accepts ragged arrays, which will raise problems in the future.
# We'll have to update this eventually, but this will have to do for now
test_idxs = []
for train, test in cv_it:
    test_idxs.append(test)
test_idxs = np.array(test_idxs, dtype='object')

# print each fold's split (fixed random_state, so these match the folds above)
for i, (train_index, test_index) in enumerate(cv.split(sustain_data, labels)):
    print(f"Fold {i}:")
    print(f"  Train: index={train_index}")
    print(f"  Test:  index={test_index}")

# perform cross-validation and output the cross-validation information criterion and
# log-likelihood on the test set for each subtypes model and fold combination
CVIC, loglike_matrix     = sustain_input.cross_validate_sustain_model(test_idxs)
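One plausible explanation for the intermittent failure, sketched with NumPy only: `np.array(test_idxs, dtype='object')` builds a 1-D array of integer index arrays only when the folds differ in length. When every fold happens to contain the same number of test samples (which can depend on the number of subjects, `N_folds`, and the class balance), NumPy instead builds a 2-D object array, so `test_idxs[fold]` becomes a row with dtype `object`, matching the traceback. The fold contents are hypothetical, and `pack_ragged` is a hypothetical caller-side workaround that assumes `cross_validate_sustain_model` only needs `test_idxs[fold]` to return each fold's index array.

```python
import numpy as np

def pack_like_workshop(folds):
    # Same packing as the workshop snippet: np.array(..., dtype='object')
    return np.array(folds, dtype='object')

def pack_ragged(folds):
    # Hypothetical workaround: always a 1-D object array whose elements
    # keep their original integer dtype, whatever the fold sizes
    packed = np.empty(len(folds), dtype=object)
    for i, fold in enumerate(folds):
        packed[i] = fold
    return packed

# Hypothetical test-index folds, as integer arrays like those from cv.split
equal_folds = [np.array([0, 3]), np.array([1, 4]), np.array([2, 5])]
ragged_folds = [np.array([0, 3, 6]), np.array([1, 4]), np.array([2, 5])]

for name, folds in [("equal", equal_folds), ("ragged", ragged_folds)]:
    workshop = pack_like_workshop(folds)
    fixed = pack_ragged(folds)
    # equal-size folds give a 2-D array whose rows have dtype object
    print(name, workshop.shape, workshop[0].dtype, "| workaround:", fixed.shape, fixed[0].dtype)
```

Filling a preallocated object array element by element stops NumPy from merging equal-length folds into a 2-D array, so each `test_idxs[fold]` keeps its integer dtype.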

Thanks 😊

@xullllllll

I've never seen an error like that. Can I ask you a question?
