
Error with example fMOW command: incorrect value of "unlabeled_n_groups_per_batch" #152

Open
joshuafan opened this issue Oct 2, 2023 · 0 comments


Hello,
If I directly run this command suggested in the README:
python examples/run_expt.py --dataset fmow --algorithm DANN --unlabeled_split test_unlabeled --root_dir data

I get the following exception:

Traceback (most recent call last):
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/run_expt.py", line 491, in <module>
    main()
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/run_expt.py", line 454, in main
    train(
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/train.py", line 114, in train
    run_epoch(algorithm, datasets['train'], general_logger, epoch, config, train=True, unlabeled_dataset=unlabeled_dataset)
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/train.py", line 38, in run_epoch
    unlabeled_data_iterator = InfiniteDataIterator(unlabeled_dataset['loader'])
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/utils.py", line 393, in __init__
    self.iter = iter(self.data_loader)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 442, in __iter__
    return self._get_iterator()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 388, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1085, in __init__
    self._reset(loader, first_iter=True)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1118, in _reset
    self._try_put_index()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1352, in _try_put_index
    index = self._next_index()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 624, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/wilds/common/data_loaders.py", line 131, in __iter__
    groups_for_batch = np.random.choice(
  File "mtrand.pyx", line 984, in numpy.random.mtrand.RandomState.choice
ValueError: Cannot take a larger sample than population when 'replace=False'

I think this occurs because the test_unlabeled split contains only 2 unique years, but unlabeled_n_groups_per_batch is set to 8. The group sampler therefore tries to draw 8 distinct years without replacement, and np.random.choice raises the ValueError above.
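The failure can be reproduced outside WILDS with NumPy alone. This is a minimal sketch: the two year values are illustrative placeholders, since the issue only states that there are 2 unique years, and `try_sample` is a hypothetical helper that wraps the same `np.random.choice(..., replace=False)` call shown in the traceback.

```python
import numpy as np

# Illustrative placeholders: the issue only says the split has 2 unique years.
unique_years = np.array([2016, 2017])

def try_sample(unique_groups, n_groups_per_batch):
    """Hypothetical helper: return None on success, or the ValueError message."""
    try:
        # Same call pattern as the traceback: choose groups without replacement.
        np.random.choice(unique_groups, size=n_groups_per_batch, replace=False)
        return None
    except ValueError as err:
        return str(err)

msg_8 = try_sample(unique_years, 8)  # 8 > 2 unique groups, so this fails
msg_2 = try_sample(unique_years, 2)  # 2 <= 2 unique groups, so this succeeds
print(msg_8)
print(msg_2)
```

With 8 requested groups the call fails with the same "Cannot take a larger sample than population" message. With 2 requested groups it returns normally.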

I was able to fix this by changing the argument unlabeled_n_groups_per_batch to 2, here: https://github.com/p-lambda/wilds/blob/main/examples/configs/datasets.py#L220
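As an alternative to hard-coding 2, a group sampler could clamp the requested number of groups to the number of groups actually present. The sketch below is hypothetical and is not the WILDS `GroupSampler` implementation. The function name, its parameters, and the even split of the batch across groups are all assumptions made for illustration.

```python
import numpy as np

def sample_group_batch(group_ids, n_groups_per_batch, batch_size, rng=None):
    """Hypothetical sketch: sample example indices for one batch, grouped by group id.

    n_groups_per_batch is clamped to the number of unique groups, so choosing
    groups without replacement cannot raise the ValueError from this issue.
    """
    rng = np.random.default_rng() if rng is None else rng
    group_ids = np.asarray(group_ids)
    unique_groups = np.unique(group_ids)
    # Clamp: never request more distinct groups than the split contains.
    n_groups = min(n_groups_per_batch, len(unique_groups))
    chosen = rng.choice(unique_groups, size=n_groups, replace=False)
    # Spread batch_size as evenly as possible over the chosen groups.
    sizes = [batch_size // n_groups + (1 if i < batch_size % n_groups else 0)
             for i in range(n_groups)]
    batch = []
    for group, size in zip(chosen, sizes):
        members = np.flatnonzero(group_ids == group)
        # Sample with replacement inside a group in case it has fewer than `size` examples.
        batch.extend(rng.choice(members, size=size, replace=True).tolist())
    return batch

# Usage with placeholder data: 2 unique years, 8 groups requested.
years = [2016] * 50 + [2017] * 50
batch = sample_group_batch(years, n_groups_per_batch=8, batch_size=64,
                           rng=np.random.default_rng(0))
print(len(batch))
```

Clamping keeps the configured value as an upper bound instead of a hard requirement. The trade-off is that a batch may silently contain fewer groups than the config suggests.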

It would be great if this could be fixed. Thank you so much for releasing these wonderful datasets and baseline algorithms!
